1 //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector transfer operations to SCF.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include <type_traits>
14 
15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
16 
17 #include "../PassDetail.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
20 #include "mlir/Dialect/MemRef/IR/MemRef.h"
21 #include "mlir/Dialect/SCF/IR/SCF.h"
22 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
23 #include "mlir/IR/Builders.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/Pass/Pass.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "mlir/Transforms/Passes.h"
28 
29 using namespace mlir;
30 using vector::TransferReadOp;
31 using vector::TransferWriteOp;
32 
33 namespace {
34 
35 /// Attribute name used for labeling transfer ops during progressive lowering.
36 static const char kPassLabel[] = "__vector_to_scf_lowering__";
37 
38 /// Patterns that inherit from this struct have access to
39 /// VectorTransferToSCFOptions.
40 template <typename OpTy>
41 struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
42   explicit VectorToSCFPattern(MLIRContext *context,
43                               VectorTransferToSCFOptions opt)
44       : OpRewritePattern<OpTy>(context), options(opt) {}
45 
46   VectorTransferToSCFOptions options;
47 };
48 
49 /// Given a vector transfer op, calculate which dimension of the `source`
50 /// memref should be unpacked in the next application of TransferOpConversion.
51 /// A return value of None indicates a broadcast.
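///
/// E.g., for an (illustrative) permutation map (d0, d1, d2) -> (d2, d1), the
/// next unpacked dimension is 2, determined by the first result of the map.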
52 template <typename OpTy>
53 static Optional<int64_t> unpackedDim(OpTy xferOp) {
54   // TODO: support 0-d corner case.
55   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
56   auto map = xferOp.getPermutationMap();
57   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
58     return expr.getPosition();
59   }
60   assert(xferOp.isBroadcastDim(0) &&
61          "Expected AffineDimExpr or AffineConstantExpr");
62   return None;
63 }
64 
65 /// Compute the permutation map for the new (N-1)-D vector transfer op. This
66 /// map is identical to the current permutation map, but the first result is
67 /// omitted.
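///
/// E.g., an (illustrative) map (d0, d1, d2) -> (d2, d1) becomes
/// (d0, d1, d2) -> (d1): the number of dims is kept and only the first result
/// is dropped.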
68 template <typename OpTy>
69 static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
70   // TODO: support 0-d corner case.
71   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
72   auto map = xferOp.getPermutationMap();
73   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
74                         b.getContext());
75 }
76 
77 /// Calculate the indices for the new vector transfer op.
78 ///
79 /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
80 ///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
81 ///                                 ^^^^^^
82 ///              `iv` is the iteration variable of the (new) surrounding loop.
83 template <typename OpTy>
84 static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
85                            SmallVector<Value, 8> &indices) {
86   typename OpTy::Adaptor adaptor(xferOp);
87   // Corresponding memref dim of the vector dim that is unpacked.
88   auto dim = unpackedDim(xferOp);
89   auto prevIndices = adaptor.getIndices();
90   indices.append(prevIndices.begin(), prevIndices.end());
91 
92   Location loc = xferOp.getLoc();
93   bool isBroadcast = !dim.has_value();
94   if (!isBroadcast) {
95     AffineExpr d0, d1;
96     bindDims(xferOp.getContext(), d0, d1);
97     Value offset = adaptor.getIndices()[dim.value()];
98     indices[dim.value()] =
99         makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
100   }
101 }
102 
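/// Generate an scf.yield op. If `hasRetVal` is true, `value` is yielded;
/// otherwise nothing is yielded.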
103 static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
104                             Value value) {
105   if (hasRetVal) {
106     assert(value && "Expected non-empty value");
107     b.create<scf::YieldOp>(loc, value);
108   } else {
109     b.create<scf::YieldOp>(loc);
110   }
111 }
112 
113 /// Generate a boolean Value that is true if the iv-th element of xferOp's
114 /// mask is set. No such check is generated under the following circumstances:
115 /// * xferOp does not have a mask.
116 /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
117 ///   computed and attached to the new transfer op in the pattern.)
118 /// * The to-be-unpacked dim of xferOp is a broadcast.
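///
/// When a check is generated, it is roughly the following (a sketch only; the
/// mask shape is illustrative):
/// ```
/// %in_mask = vector.extractelement %mask[%iv : index] : vector<4xi1>
/// ```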
119 template <typename OpTy>
120 static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
121   if (!xferOp.getMask())
122     return Value();
123   if (xferOp.getMaskType().getRank() != 1)
124     return Value();
125   if (xferOp.isBroadcastDim(0))
126     return Value();
127 
128   Location loc = xferOp.getLoc();
129   return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
130 }
131 
132 /// Helper function for TransferOpConversion and TransferOp1dConversion.
133 /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
134 /// specified dimension `dim` with the loop iteration variable `iv`.
135 /// E.g., when unpacking dimension 0 from:
136 /// ```
137 /// %vec = vector.transfer_read %A[%a, %b], %cst
138 ///     : memref<?x?xf32>, vector<5x4xf32>
139 /// ```
140 /// An if check similar to this will be generated inside the loop:
141 /// ```
142 /// %d = memref.dim %A, %c0 : memref<?x?xf32>
143 /// if (%a + iv < %d) {
144 ///   (in-bounds case)
145 /// } else {
146 ///   (out-of-bounds case)
147 /// }
148 /// ```
149 ///
150 /// If the transfer is 1D and has a mask, this function generates a more
151 /// complex check that also accounts for potentially masked-out elements.
152 ///
153 /// This function variant returns the value returned by `inBoundsCase` or
154 /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
155 /// `resultTypes`.
156 template <typename OpTy>
157 static Value generateInBoundsCheck(
158     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
159     TypeRange resultTypes,
160     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
161     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
162   bool hasRetVal = !resultTypes.empty();
163   Value cond; // Condition to be built...
164 
165   // Condition check 1: Access in-bounds?
166   bool isBroadcast = !dim; // No in-bounds check for broadcasts.
167   Location loc = xferOp.getLoc();
168   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
169   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
170     Value memrefDim =
171         vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
172     AffineExpr d0, d1;
173     bindDims(xferOp.getContext(), d0, d1);
174     Value base = xferOp.getIndices()[*dim];
175     Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
176     cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
177                                     memrefIdx);
178   }
179 
180   // Condition check 2: Masked in?
181   if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
182     if (cond)
183       cond = lb.create<arith::AndIOp>(cond, maskCond);
184     else
185       cond = maskCond;
186   }
187 
188   // If the condition is non-empty, generate an SCF::IfOp.
189   if (cond) {
190     auto check = lb.create<scf::IfOp>(
191         resultTypes, cond,
192         /*thenBuilder=*/
193         [&](OpBuilder &b, Location loc) {
194           maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
195         },
196         /*elseBuilder=*/
197         [&](OpBuilder &b, Location loc) {
198           if (outOfBoundsCase) {
199             maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
200           } else {
201             b.create<scf::YieldOp>(loc);
202           }
203         });
204 
205     return hasRetVal ? check.getResult(0) : Value();
206   }
207 
208   // Condition is empty, no need for an SCF::IfOp.
209   return inBoundsCase(b, loc);
210 }
211 
212 /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
213 /// a return value. Consequently, this function does not have a return value.
214 template <typename OpTy>
215 static void generateInBoundsCheck(
216     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
217     function_ref<void(OpBuilder &, Location)> inBoundsCase,
218     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
219   generateInBoundsCheck(
220       b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
221       /*inBoundsCase=*/
222       [&](OpBuilder &b, Location loc) {
223         inBoundsCase(b, loc);
224         return Value();
225       },
226       /*outOfBoundsCase=*/
227       [&](OpBuilder &b, Location loc) {
228         if (outOfBoundsCase)
229           outOfBoundsCase(b, loc);
230         return Value();
231       });
232 }
233 
234 /// Given an ArrayAttr, return a copy where the first element is dropped.
235 static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
236   if (!attr)
237     return attr;
238   return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
239 }
240 
241 /// Add the pass label to a vector transfer op if its rank is greater than
242 /// the target rank.
243 template <typename OpTy>
244 static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
245                                 unsigned targetRank) {
246   if (newXferOp.getVectorType().getRank() > targetRank)
247     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
248 }
249 
250 /// Return true if this transfer op operates on a source tensor.
251 template <typename OpTy>
252 static bool isTensorOp(OpTy xferOp) {
253   if (xferOp.getShapedType().template isa<RankedTensorType>()) {
254     if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
255       // TransferWriteOps on tensors have a result.
256       assert(xferOp->getNumResults() > 0);
257     }
258     return true;
259   }
260   return false;
261 }
262 
263 namespace lowering_n_d {
264 
265 /// Helper data structure for data and mask buffers.
266 struct BufferAllocs {
267   Value dataBuffer;
268   Value maskBuffer;
269 };
270 
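/// Return the closest surrounding op that has the AutomaticAllocationScope
/// trait (typically a function), i.e., the scope at whose entry the temporary
/// buffers below are alloca'ed.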
271 // TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
272 static Operation *getAutomaticAllocationScope(Operation *op) {
273   Operation *scope =
274       op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
275   assert(scope && "Expected op to be inside automatic allocation scope");
276   return scope;
277 }
278 
279 /// Allocate temporary buffers for data (vector) and mask (if present).
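/// For a transfer op with vector<5x4xf32> and a vector<5xi1> mask (types are
/// illustrative), this roughly creates, at the top of the enclosing
/// allocation scope:
/// ```
/// %data_buf = memref.alloca() : memref<vector<5x4xf32>>
/// %mask_buf = memref.alloca() : memref<vector<5xi1>>
/// ```
/// The mask is then stored into (and re-loaded from) its buffer right before
/// the transfer op.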
280 template <typename OpTy>
281 static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
282   Location loc = xferOp.getLoc();
283   OpBuilder::InsertionGuard guard(b);
284   Operation *scope = getAutomaticAllocationScope(xferOp);
285   assert(scope->getNumRegions() == 1 &&
286          "AutomaticAllocationScope with >1 regions");
287   b.setInsertionPointToStart(&scope->getRegion(0).front());
288 
289   BufferAllocs result;
290   auto bufferType = MemRefType::get({}, xferOp.getVectorType());
291   result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
292 
293   if (xferOp.getMask()) {
294     auto maskType = MemRefType::get({}, xferOp.getMask().getType());
295     auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
296     b.setInsertionPoint(xferOp);
297     b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
298     result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
299   }
300 
301   return result;
302 }
303 
304 /// Given a MemRefType with VectorType element type, unpack one dimension from
305 /// the VectorType into the MemRefType.
306 ///
307 /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
308 static MemRefType unpackOneDim(MemRefType type) {
309   auto vectorType = type.getElementType().dyn_cast<VectorType>();
310   auto memrefShape = type.getShape();
311   SmallVector<int64_t, 8> newMemrefShape;
312   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
313   newMemrefShape.push_back(vectorType.getDimSize(0));
314   return MemRefType::get(newMemrefShape,
315                          VectorType::get(vectorType.getShape().drop_front(),
316                                          vectorType.getElementType()));
317 }
318 
319 /// Given a transfer op, find the memref from which the mask is loaded. This
320 /// is similar to Strategy<TransferWriteOp>::getBuffer.
321 template <typename OpTy>
322 static Value getMaskBuffer(OpTy xferOp) {
323   assert(xferOp.getMask() && "Expected that transfer op has mask");
324   auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
325   assert(loadOp && "Expected transfer op mask produced by LoadOp");
326   return loadOp.getMemRef();
327 }
328 
329 /// Codegen strategy, depending on the operation.
330 template <typename OpTy>
331 struct Strategy;
332 
333 /// Codegen strategy for vector TransferReadOp.
334 template <>
335 struct Strategy<TransferReadOp> {
336   /// Find the StoreOp that is used for writing the current TransferReadOp's
337   /// result to the temporary buffer allocation.
338   static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
339     assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
340     auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
341     assert(storeOp && "Expected TransferReadOp result used by StoreOp");
342     return storeOp;
343   }
344 
345   /// Find the temporary buffer allocation. All labeled TransferReadOps are
346   /// used like this, where %buf is either the buffer allocation or a type cast
347   /// of the buffer allocation:
348   /// ```
349   /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
350   /// memref.store %vec, %buf[...] ...
351   /// ```
352   static Value getBuffer(TransferReadOp xferOp) {
353     return getStoreOp(xferOp).getMemRef();
354   }
355 
356   /// Retrieve the indices of the current StoreOp that stores into the buffer.
357   static void getBufferIndices(TransferReadOp xferOp,
358                                SmallVector<Value, 8> &indices) {
359     auto storeOp = getStoreOp(xferOp);
360     auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
361     indices.append(prevIndices.begin(), prevIndices.end());
362   }
363 
364   /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
365   /// accesses on the to-be-unpacked dimension.
366   ///
367   /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
368   ///    variable `iv`.
369   /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
370   ///
371   /// E.g.:
372   /// ```
373   /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
374   ///     : memref<?x?x?xf32>, vector<4x3xf32>
375   /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
376   /// ```
377   /// Is rewritten to:
378   /// ```
379   /// %casted = vector.type_cast %buf
380   ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
381   /// for %j = 0 to 4 {
382   ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
383   ///       : memref<?x?x?xf32>, vector<3xf32>
384   ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
385   /// }
386   /// ```
387   ///
388   /// Note: The loop and type cast are generated in TransferOpConversion.
389   ///       The original TransferReadOp and store op are deleted in `cleanup`.
390   /// Note: The `mask` operand is set in TransferOpConversion.
391   static TransferReadOp rewriteOp(OpBuilder &b,
392                                   VectorTransferToSCFOptions options,
393                                   TransferReadOp xferOp, Value buffer, Value iv,
394                                   ValueRange /*loopState*/) {
395     SmallVector<Value, 8> storeIndices;
396     getBufferIndices(xferOp, storeIndices);
397     storeIndices.push_back(iv);
398 
399     SmallVector<Value, 8> xferIndices;
400     getXferIndices(b, xferOp, iv, xferIndices);
401 
402     Location loc = xferOp.getLoc();
403     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
404     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
405     auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
406     auto newXferOp = b.create<vector::TransferReadOp>(
407         loc, vecType, xferOp.getSource(), xferIndices,
408         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
409         xferOp.getPadding(), Value(), inBoundsAttr);
410 
411     maybeApplyPassLabel(b, newXferOp, options.targetRank);
412 
413     b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
414     return newXferOp;
415   }
416 
417   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
418   /// padding value to the temporary buffer.
419   static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
420                                     Value buffer, Value iv,
421                                     ValueRange /*loopState*/) {
422     SmallVector<Value, 8> storeIndices;
423     getBufferIndices(xferOp, storeIndices);
424     storeIndices.push_back(iv);
425 
426     Location loc = xferOp.getLoc();
427     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
428     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
429     auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
430     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
431 
432     return Value();
433   }
434 
435   /// Cleanup after rewriting the op.
436   static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
437                       scf::ForOp /*forOp*/) {
438     rewriter.eraseOp(getStoreOp(xferOp));
439     rewriter.eraseOp(xferOp);
440   }
441 
442   /// Return the initial loop state for the generated scf.for loop.
443   static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
444 };
445 
446 /// Codegen strategy for vector TransferWriteOp.
447 template <>
448 struct Strategy<TransferWriteOp> {
449   /// Find the temporary buffer allocation. All labeled TransferWriteOps are
450   /// used like this, where %buf is either the buffer allocation or a type cast
451   /// of the buffer allocation:
452   /// ```
453   /// %vec = memref.load %buf[...] ...
454   /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
455   /// ```
456   static Value getBuffer(TransferWriteOp xferOp) {
457     auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
458     assert(loadOp && "Expected transfer op vector produced by LoadOp");
459     return loadOp.getMemRef();
460   }
461 
462   /// Retrieve the indices of the current LoadOp that loads from the buffer.
463   static void getBufferIndices(TransferWriteOp xferOp,
464                                SmallVector<Value, 8> &indices) {
465     auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
466     auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
467     indices.append(prevIndices.begin(), prevIndices.end());
468   }
469 
470   /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
471   /// accesses on the to-be-unpacked dimension.
472   ///
473   /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
474   ///    using the loop iteration variable `iv`.
475   /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
476   ///    to memory.
477   ///
478   /// Note: For more details, see comments on Strategy<TransferReadOp>.
479   static TransferWriteOp rewriteOp(OpBuilder &b,
480                                    VectorTransferToSCFOptions options,
481                                    TransferWriteOp xferOp, Value buffer,
482                                    Value iv, ValueRange loopState) {
483     SmallVector<Value, 8> loadIndices;
484     getBufferIndices(xferOp, loadIndices);
485     loadIndices.push_back(iv);
486 
487     SmallVector<Value, 8> xferIndices;
488     getXferIndices(b, xferOp, iv, xferIndices);
489 
490     Location loc = xferOp.getLoc();
491     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
492     auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
493     auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
494     Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
495     auto newXferOp = b.create<vector::TransferWriteOp>(
496         loc, type, vec, source, xferIndices,
497         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
498         inBoundsAttr);
499 
500     maybeApplyPassLabel(b, newXferOp, options.targetRank);
501 
502     return newXferOp;
503   }
504 
505   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
506   static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
507                                     Value buffer, Value iv,
508                                     ValueRange loopState) {
509     return isTensorOp(xferOp) ? loopState[0] : Value();
510   }
511 
512   /// Cleanup after rewriting the op.
513   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
514                       scf::ForOp forOp) {
515     if (isTensorOp(xferOp)) {
516       assert(forOp->getNumResults() == 1 && "Expected one for loop result");
517       rewriter.replaceOp(xferOp, forOp->getResult(0));
518     } else {
519       rewriter.eraseOp(xferOp);
520     }
521   }
522 
523   /// Return the initial loop state for the generated scf.for loop.
524   static Value initialLoopState(TransferWriteOp xferOp) {
525     return isTensorOp(xferOp) ? xferOp.getSource() : Value();
526   }
527 };
528 
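/// Return success if the transfer op is eligible for progressive lowering:
/// it is not yet labeled, its vector rank exceeds the target rank, tensor
/// sources are lowered only if `lowerTensors` is set, and the vector and
/// source element types match.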
529 template <typename OpTy>
530 LogicalResult checkPrepareXferOp(OpTy xferOp,
531                                  VectorTransferToSCFOptions options) {
532   if (xferOp->hasAttr(kPassLabel))
533     return failure();
534   if (xferOp.getVectorType().getRank() <= options.targetRank)
535     return failure();
536   if (isTensorOp(xferOp) && !options.lowerTensors)
537     return failure();
538   // Transfer ops that modify the element type are not supported atm.
539   if (xferOp.getVectorType().getElementType() !=
540       xferOp.getShapedType().getElementType())
541     return failure();
542   return success();
543 }
544 
545 /// Prepare a TransferReadOp for progressive lowering.
546 ///
547 /// 1. Allocate a temporary buffer.
548 /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
549 /// 3. Store the result of the TransferReadOp into the temporary buffer.
550 /// 4. Load the result from the temporary buffer and replace all uses of the
551 ///    original TransferReadOp with this load.
552 ///
553 /// E.g.:
554 /// ```
555 /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
556 ///     : memref<?x?x?xf32>, vector<5x4xf32>
557 /// ```
558 /// is rewritten to:
559 /// ```
560 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
561 /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
562 ///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
563 /// memref.store %1, %0[] : memref<vector<5x4xf32>>
564 /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
565 /// ```
566 ///
567 /// Note: A second temporary buffer may be allocated for the `mask` operand.
568 struct PrepareTransferReadConversion
569     : public VectorToSCFPattern<TransferReadOp> {
570   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
571 
572   LogicalResult matchAndRewrite(TransferReadOp xferOp,
573                                 PatternRewriter &rewriter) const override {
574     if (checkPrepareXferOp(xferOp, options).failed())
575       return failure();
576 
577     auto buffers = allocBuffers(rewriter, xferOp);
578     auto *newXfer = rewriter.clone(*xferOp.getOperation());
579     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
580     if (xferOp.getMask()) {
581       dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
582           buffers.maskBuffer);
583     }
584 
585     Location loc = xferOp.getLoc();
586     rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
587                                      buffers.dataBuffer);
588     rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
589 
590     return success();
591   }
592 };
593 
594 /// Prepare a TransferWriteOp for progressive lowering.
595 ///
596 /// 1. Allocate a temporary buffer.
597 /// 2. Store the vector into the buffer.
598 /// 3. Load the vector from the buffer again.
599 /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
600 ///    marking it eligible for progressive lowering via TransferOpConversion.
601 ///
602 /// E.g.:
603 /// ```
604 /// vector.transfer_write %vec, %A[%a, %b, %c]
605 ///     : vector<5x4xf32>, memref<?x?x?xf32>
606 /// ```
607 /// is rewritten to:
608 /// ```
609 /// %0 = memref.alloca() : memref<vector<5x4xf32>>
610 /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
611 /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
612 /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
613 ///     : vector<5x4xf32>, memref<?x?x?xf32>
614 /// ```
615 ///
616 /// Note: A second temporary buffer may be allocated for the `mask` operand.
617 struct PrepareTransferWriteConversion
618     : public VectorToSCFPattern<TransferWriteOp> {
619   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
620 
621   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
622                                 PatternRewriter &rewriter) const override {
623     if (checkPrepareXferOp(xferOp, options).failed())
624       return failure();
625 
626     Location loc = xferOp.getLoc();
627     auto buffers = allocBuffers(rewriter, xferOp);
628     rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
629                                      buffers.dataBuffer);
630     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
631     rewriter.updateRootInPlace(xferOp, [&]() {
632       xferOp.getVectorMutable().assign(loadedVec);
633       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
634     });
635 
636     if (xferOp.getMask()) {
637       rewriter.updateRootInPlace(xferOp, [&]() {
638         xferOp.getMaskMutable().assign(buffers.maskBuffer);
639       });
640     }
641 
642     return success();
643   }
644 };
645 
646 /// Progressive lowering of vector transfer ops: Unpack one dimension.
647 ///
648 /// 1. Unpack one dimension from the current buffer type and cast the buffer
649 ///    to that new type. E.g.:
650 ///    ```
651 ///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
652 ///    vector.transfer_write %vec ...
653 ///    ```
654 ///    The following cast is generated:
655 ///    ```
656 ///    %casted = vector.type_cast %0
657 ///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
658 ///    ```
659 /// 2. Generate a for loop and rewrite the transfer op according to the
660 ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
661 ///    out-of-bounds, generate an if-check and handle both cases separately.
662 /// 3. Clean up according to the corresponding Strategy<OpTy>.
663 ///
664 /// Note: If the transfer op is a TransferWriteOp and operates on a tensor
665 /// source (as opposed to a memref source), then each iteration of the generated
666 /// scf.for loop yields the new tensor value. E.g.:
667 /// ```
668 /// %result = scf.for i = 0 to 5 {
669 ///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
670 ///   %1 = vector.transfer_write %0, %source[...]
671 ///       : vector<4x3xf32>, tensor<5x4x3xf32>
672 ///   scf.yield %1 : tensor<5x4x3xf32>
673 /// }
674 /// ```
675 template <typename OpTy>
676 struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
677   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
678 
679   void initialize() {
680     // This pattern recursively unpacks one dimension at a time. The
681     // recursion is bounded, as the rank is strictly decreasing.
682     this->setHasBoundedRewriteRecursion();
683   }
684 
685   LogicalResult matchAndRewrite(OpTy xferOp,
686                                 PatternRewriter &rewriter) const override {
687     if (!xferOp->hasAttr(kPassLabel))
688       return failure();
689 
690     // Find and cast data buffer. How the buffer can be found depends on OpTy.
691     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
692     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
693     auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
694     auto castedDataType = unpackOneDim(dataBufferType);
695     auto castedDataBuffer =
696         locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
697 
698     // If the xferOp has a mask: Find and cast mask buffer.
699     Value castedMaskBuffer;
700     if (xferOp.getMask()) {
701       auto maskBuffer = getMaskBuffer(xferOp);
702       auto maskBufferType =
703           maskBuffer.getType().template dyn_cast<MemRefType>();
704       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
705         // Do not unpack a dimension of the mask, if:
706         // * To-be-unpacked transfer op dimension is a broadcast.
707         // * Mask is 1D, i.e., the mask cannot be further unpacked.
708         //   (That means that all remaining dimensions of the transfer op must
709         //   be broadcasted.)
710         castedMaskBuffer = maskBuffer;
711       } else {
712         auto castedMaskType = unpackOneDim(maskBufferType);
713         castedMaskBuffer =
714             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
715       }
716     }
717 
718     // Loop bounds and step.
719     auto lb = locB.create<arith::ConstantIndexOp>(0);
720     auto ub = locB.create<arith::ConstantIndexOp>(
721         castedDataType.getDimSize(castedDataType.getRank() - 1));
722     auto step = locB.create<arith::ConstantIndexOp>(1);
723     // TransferWriteOps that operate on tensors return the modified tensor and
724     // require a loop state.
725     auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
726 
727     // Generate for loop.
728     auto result = locB.create<scf::ForOp>(
729         lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
730         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
731           Type stateType = loopState.empty() ? Type() : loopState[0].getType();
732 
733           auto result = generateInBoundsCheck(
734               b, xferOp, iv, unpackedDim(xferOp),
735               stateType ? TypeRange(stateType) : TypeRange(),
736               /*inBoundsCase=*/
737               [&](OpBuilder &b, Location loc) {
738                 // Create new transfer op.
739                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
740                     b, this->options, xferOp, castedDataBuffer, iv, loopState);
741 
742                 // If old transfer op has a mask: Set mask on new
743                 // transfer op. Special case: If the mask of the old
744                 // transfer op is 1D and the unpacked dim is not a
745                 // broadcast, no mask is needed on the new transfer op;
746                 // `generateInBoundsCheck` already evaluated the mask.
747                 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
748                                          xferOp.getMaskType().getRank() > 1)) {
749                   OpBuilder::InsertionGuard guard(b);
750                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
751 
752                   SmallVector<Value, 8> loadIndices;
753                   Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
754                   // In case of broadcast: Use same indices to load from memref
755                   // as before.
756                   if (!xferOp.isBroadcastDim(0))
757                     loadIndices.push_back(iv);
758 
759                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
760                                                        loadIndices);
761                   rewriter.updateRootInPlace(newXfer, [&]() {
762                     newXfer.getMaskMutable().assign(mask);
763                   });
764                 }
765 
766                 return loopState.empty() ? Value() : newXfer->getResult(0);
767               },
768               /*outOfBoundsCase=*/
769               [&](OpBuilder &b, Location /*loc*/) {
770                 return Strategy<OpTy>::handleOutOfBoundsDim(
771                     b, xferOp, castedDataBuffer, iv, loopState);
772               });
773 
774           maybeYieldValue(b, loc, !loopState.empty(), result);
775         });
776 
777     Strategy<OpTy>::cleanup(rewriter, xferOp, result);
778     return success();
779   }
780 };
781 
782 } // namespace lowering_n_d
783 
784 namespace lowering_n_d_unrolled {
785 
786 /// If the original transfer op has a mask, compute the mask of the new transfer
787 /// op (for the current iteration `i`) and assign it.
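///
/// For an (N>1)-D mask, one mask dimension is peeled off, roughly as follows
/// (a sketch; the mask shape is illustrative):
/// ```
/// %new_mask = vector.extract %mask[i] : vector<5x4xi1>
/// ```
/// For a 1-D, non-broadcast mask nothing is assigned here; see the note at
/// the end of this function.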
788 template <typename OpTy>
789 static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
790                             int64_t i) {
791   if (!xferOp.getMask())
792     return;
793 
794   if (xferOp.isBroadcastDim(0)) {
795     // To-be-unpacked dimension is a broadcast, which does not have a
796     // corresponding mask dimension. Mask attribute remains unchanged.
797     newXferOp.getMaskMutable().assign(xferOp.getMask());
798     return;
799   }
800 
801   if (xferOp.getMaskType().getRank() > 1) {
802     // Unpack one dimension of the mask.
803     OpBuilder::InsertionGuard guard(b);
804     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
805 
806     llvm::SmallVector<int64_t, 1> indices({i});
807     Location loc = xferOp.getLoc();
808     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
809     newXferOp.getMaskMutable().assign(newMask);
810   }
811 
812   // If we end up here: The mask of the old transfer op is 1D and the unpacked
813   // dim is not a broadcast, so no mask is needed on the new transfer op.
814   // `generateInBoundsCheck` will have evaluated the mask already.
815 }
816 
817 /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
818 /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
819 /// memref buffer is allocated and the SCF loop is fully unrolled.
820 ///
821 /// E.g.:
822 /// ```
824 /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
825 ///     : memref<?x?x?xf32>, vector<5x4xf32>
826 /// ```
827 /// is rewritten to IR such as (simplified):
828 /// ```
829 /// %v_init = splat %padding : vector<5x4xf32>
830 /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
831 ///     : memref<?x?x?xf32>, vector<4xf32>
832 /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
833 /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
834 ///     : memref<?x?x?xf32>, vector<4xf32>
835 /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
836 /// ...
837 /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
838 ///     : memref<?x?x?xf32>, vector<4xf32>
839 /// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
840 /// ```
841 ///
842 /// Note: As an optimization, if the result of the original TransferReadOp
843 /// was directly inserted into another vector, no new %v_init vector is created.
844 /// Instead, the new TransferReadOp results are inserted into that vector.
845 struct UnrollTransferReadConversion
846     : public VectorToSCFPattern<TransferReadOp> {
847   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
848 
849   void initialize() {
850     // This pattern recursively unpacks one dimension at a time. The
851     // recursion is bounded, as the rank is strictly decreasing.
852     setHasBoundedRewriteRecursion();
853   }
854 
855   /// Return the vector into which the newly created TransferReadOp results
856   /// are inserted.
857   Value getResultVector(TransferReadOp xferOp,
858                         PatternRewriter &rewriter) const {
859     if (auto insertOp = getInsertOp(xferOp))
860       return insertOp.getDest();
861     Location loc = xferOp.getLoc();
862     return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
863                                             xferOp.getPadding());
864   }
865 
866   /// If the result of the TransferReadOp has exactly one user, which is a
867   /// vector::InsertOp, return that operation.
868   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
869     if (xferOp->hasOneUse()) {
870       Operation *xferOpUser = *xferOp->getUsers().begin();
871       if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
872         return insertOp;
873     }
874 
875     return vector::InsertOp();
876   }
877 
878   /// If the result of the TransferReadOp has exactly one user, which is a
879   /// vector::InsertOp, return that operation's indices.
880   void getInsertionIndices(TransferReadOp xferOp,
881                            SmallVector<int64_t, 8> &indices) const {
882     if (auto insertOp = getInsertOp(xferOp)) {
883       for (Attribute attr : insertOp.getPosition())
884         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
885     }
886   }
887 
888   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
889   /// accesses, and broadcasts and transposes in permutation maps.
890   LogicalResult matchAndRewrite(TransferReadOp xferOp,
891                                 PatternRewriter &rewriter) const override {
892     if (xferOp.getVectorType().getRank() <= options.targetRank)
893       return failure();
894     if (isTensorOp(xferOp) && !options.lowerTensors)
895       return failure();
896     // Transfer ops that modify the element type are not supported atm.
897     if (xferOp.getVectorType().getElementType() !=
898         xferOp.getShapedType().getElementType())
899       return failure();
900 
901     auto insertOp = getInsertOp(xferOp);
902     auto vec = getResultVector(xferOp, rewriter);
903     auto vecType = vec.getType().dyn_cast<VectorType>();
904     auto xferVecType = xferOp.getVectorType();
905     auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
906                                           xferVecType.getElementType());
907     int64_t dimSize = xferVecType.getShape()[0];
908 
909     // Generate fully unrolled loop of transfer ops.
910     Location loc = xferOp.getLoc();
911     for (int64_t i = 0; i < dimSize; ++i) {
912       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
913 
914       vec = generateInBoundsCheck(
915           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
916           /*inBoundsCase=*/
917           [&](OpBuilder &b, Location loc) {
918             // Indices for the new transfer op.
919             SmallVector<Value, 8> xferIndices;
920             getXferIndices(b, xferOp, iv, xferIndices);
921 
922             // Indices for the new vector.insert op.
923             SmallVector<int64_t, 8> insertionIndices;
924             getInsertionIndices(xferOp, insertionIndices);
925             insertionIndices.push_back(i);
926 
927             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
928             auto newXferOp = b.create<vector::TransferReadOp>(
929                 loc, newXferVecType, xferOp.getSource(), xferIndices,
930                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
931                 xferOp.getPadding(), Value(), inBoundsAttr);
932             maybeAssignMask(b, xferOp, newXferOp, i);
933             return b.create<vector::InsertOp>(loc, newXferOp, vec,
934                                               insertionIndices);
935           },
936           /*outOfBoundsCase=*/
937           [&](OpBuilder &b, Location loc) {
938             // Forward the original (unmodified) vector.
939             return vec;
940           });
941     }
942 
943     if (insertOp) {
944       // Rewrite single user of the old TransferReadOp, which was an InsertOp.
945       rewriter.replaceOp(insertOp, vec);
946       rewriter.eraseOp(xferOp);
947     } else {
948       rewriter.replaceOp(xferOp, vec);
949     }
950 
951     return success();
952   }
953 };
954 
955 /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
956 /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
957 /// memref buffer is allocated and the SCF loop is fully unrolled.
958 ///
959 /// E.g.:
960 /// ```
962 /// vector.transfer_write %vec, %A[%a, %b, %c]
963 ///     : vector<5x4xf32>, memref<?x?x?xf32>
964 /// ```
965 /// is rewritten to IR such as (simplified):
966 /// ```
967 /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
968 /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
969 /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
970 /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
971 /// ...
972 /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
973 /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
974 /// ```
975 ///
976 /// Note: As an optimization, if the vector of the original TransferWriteOp
977 /// was directly extracted from another vector via an ExtractOp `a`, extract
978 /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
979 /// doing so, `a` may become dead, and the number of ExtractOps generated during
980 /// recursive application of this pattern will be minimal.
981 struct UnrollTransferWriteConversion
982     : public VectorToSCFPattern<TransferWriteOp> {
983   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
984 
985   void initialize() {
986     // This pattern recursively unpacks one dimension at a time. The
987     // recursion is bounded, as the rank is strictly decreasing.
988     setHasBoundedRewriteRecursion();
989   }
990 
991   /// Return the vector from which newly generated ExtractOps will extract.
992   Value getDataVector(TransferWriteOp xferOp) const {
993     if (auto extractOp = getExtractOp(xferOp))
994       return extractOp.getVector();
995     return xferOp.getVector();
996   }
997 
998   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
999   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
1000     if (auto *op = xferOp.getVector().getDefiningOp())
1001       return dyn_cast<vector::ExtractOp>(op);
1002     return vector::ExtractOp();
1003   }
1004 
1005   /// If the input of the given TransferWriteOp is an ExtractOp, return its
1006   /// indices.
1007   void getExtractionIndices(TransferWriteOp xferOp,
1008                             SmallVector<int64_t, 8> &indices) const {
1009     if (auto extractOp = getExtractOp(xferOp)) {
1010       for (Attribute attr : extractOp.getPosition())
1011         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
1012     }
1013   }
1014 
1015   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
1016   /// accesses, and broadcasts and transposes in permutation maps.
1017   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
1018                                 PatternRewriter &rewriter) const override {
1019     if (xferOp.getVectorType().getRank() <= options.targetRank)
1020       return failure();
1021     if (isTensorOp(xferOp) && !options.lowerTensors)
1022       return failure();
1023     // Transfer ops that modify the element type are not supported atm.
1024     if (xferOp.getVectorType().getElementType() !=
1025         xferOp.getShapedType().getElementType())
1026       return failure();
1027 
1028     auto vec = getDataVector(xferOp);
1029     auto xferVecType = xferOp.getVectorType();
1030     int64_t dimSize = xferVecType.getShape()[0];
1031     auto source = xferOp.getSource(); // memref or tensor to be written to.
1032     auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
1033 
1034     // Generate fully unrolled loop of transfer ops.
1035     Location loc = xferOp.getLoc();
1036     for (int64_t i = 0; i < dimSize; ++i) {
1037       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
1038 
1039       auto updatedSource = generateInBoundsCheck(
1040           rewriter, xferOp, iv, unpackedDim(xferOp),
1041           isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1042           /*inBoundsCase=*/
1043           [&](OpBuilder &b, Location loc) {
1044             // Indices for the new transfer op.
1045             SmallVector<Value, 8> xferIndices;
1046             getXferIndices(b, xferOp, iv, xferIndices);
1047 
1048             // Indices for the new vector.extract op.
1049             SmallVector<int64_t, 8> extractionIndices;
1050             getExtractionIndices(xferOp, extractionIndices);
1051             extractionIndices.push_back(i);
1052 
1053             auto extracted =
1054                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
1055             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
1056             auto newXferOp = b.create<vector::TransferWriteOp>(
1057                 loc, sourceType, extracted, source, xferIndices,
1058                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
1059                 inBoundsAttr);
1060 
1061             maybeAssignMask(b, xferOp, newXferOp, i);
1062 
1063             return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1064           },
1065           /*outOfBoundsCase=*/
1066           [&](OpBuilder &b, Location loc) {
1067             return isTensorOp(xferOp) ? source : Value();
1068           });
1069 
1070       if (isTensorOp(xferOp))
1071         source = updatedSource;
1072     }
1073 
1074     if (isTensorOp(xferOp))
1075       rewriter.replaceOp(xferOp, source);
1076     else
1077       rewriter.eraseOp(xferOp);
1078 
1079     return success();
1080   }
1081 };
1082 
1083 } // namespace lowering_n_d_unrolled
1084 
1085 namespace lowering_1_d {
1086 
1087 /// Compute the indices into the memref for the LoadOp/StoreOp generated as
1088 /// part of TransferOp1dConversion. Return the memref dimension on which
1089 /// the transfer is operating. A return value of None indicates a broadcast.
1090 template <typename OpTy>
1091 static Optional<int64_t>
1092 get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
1093                    SmallVector<Value, 8> &memrefIndices) {
1094   auto indices = xferOp.getIndices();
1095   auto map = xferOp.getPermutationMap();
1096   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
1097 
1098   memrefIndices.append(indices.begin(), indices.end());
1099   assert(map.getNumResults() == 1 &&
1100          "Expected 1 permutation map result for 1D transfer");
1101   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
1102     Location loc = xferOp.getLoc();
1103     auto dim = expr.getPosition();
1104     AffineExpr d0, d1;
1105     bindDims(xferOp.getContext(), d0, d1);
1106     Value offset = memrefIndices[dim];
1107     memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1108     return dim;
1109   }
1110 
1111   assert(xferOp.isBroadcastDim(0) &&
1112          "Expected AffineDimExpr or AffineConstantExpr");
1113   return None;
1114 }
1115 
1116 /// Codegen strategy for TransferOp1dConversion, depending on the
1117 /// operation.
1118 template <typename OpTy>
1119 struct Strategy1d;
1120 
1121 /// Codegen strategy for TransferReadOp.
1122 template <>
1123 struct Strategy1d<TransferReadOp> {
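  /// Generate one loop iteration: if the access is in-bounds and not masked
  /// out, load a scalar element from the source and insert it into the
  /// loop-carried vector; otherwise forward the vector unchanged, keeping the
  /// padding value.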
1124   static void generateForLoopBody(OpBuilder &b, Location loc,
1125                                   TransferReadOp xferOp, Value iv,
1126                                   ValueRange loopState) {
1127     SmallVector<Value, 8> indices;
1128     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1129     auto vec = loopState[0];
1130 
1131     // In case of out-of-bounds access, leave `vec` as is (was initialized with
1132     // padding value).
1133     auto nextVec = generateInBoundsCheck(
1134         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
1135         /*inBoundsCase=*/
1136         [&](OpBuilder &b, Location loc) {
1137           Value val =
1138               b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
1139           return b.create<vector::InsertElementOp>(loc, val, vec, iv);
1140         },
1141         /*outOfBoundsCase=*/
1142         [&](OpBuilder & /*b*/, Location loc) { return vec; });
1143     b.create<scf::YieldOp>(loc, nextVec);
1144   }
1145 
1146   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
1147     // Initialize the vector with the padding value.
1148     Location loc = xferOp.getLoc();
1149     return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
1150                                      xferOp.getPadding());
1151   }
1152 };
1153 
1154 /// Codegen strategy for TransferWriteOp.
1155 template <>
1156 struct Strategy1d<TransferWriteOp> {
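  /// Generate one loop iteration: extract a scalar element from the vector
  /// and store it to the source memref, but only if the access is in-bounds
  /// and not masked out.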
1157   static void generateForLoopBody(OpBuilder &b, Location loc,
1158                                   TransferWriteOp xferOp, Value iv,
1159                                   ValueRange /*loopState*/) {
1160     SmallVector<Value, 8> indices;
1161     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
1162 
1163     // Nothing to do in case of out-of-bounds access.
1164     generateInBoundsCheck(
1165         b, xferOp, iv, dim,
1166         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
1167           auto val =
1168               b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
1169           b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
1170         });
1171     b.create<scf::YieldOp>(loc);
1172   }
1173 
1174   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1175     return Value();
1176   }
1177 };
1178 
1179 /// Return true if the last dimension of the MemRefType has unit stride.
1180 static bool isLastMemrefDimUnitStride(MemRefType type) {
1181   int64_t offset;
1182   SmallVector<int64_t, 4> strides;
1183   auto successStrides = getStridesAndOffset(type, strides, offset);
1184   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
1185 }
1186 
1187 /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
1188 /// necessary in cases where a 1D vector transfer op cannot be lowered into
1189 /// vector load/stores due to non-unit strides or broadcasts:
1190 ///
1191 /// * Transfer dimension is not the last memref dimension
1192 /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
1193 /// * Memref has a layout map with non-unit stride on the last dimension
1194 ///
1195 /// This pattern generates IR as follows:
1196 ///
1197 /// 1. Generate a for loop iterating over each vector element.
1198 /// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
1199 ///    depending on OpTy.
1200 ///
1201 /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
1202 ///       can be generated instead of TransferOp1dConversion. Add such a pattern
1203 ///       to ConvertVectorToLLVM.
1204 ///
1205 /// E.g.:
1206 /// ```
1207 /// vector.transfer_write %vec, %A[%a, %b]
1208 ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
1209 ///    : vector<9xf32>, memref<?x?xf32>
1210 /// ```
1211 /// Is rewritten to approximately the following pseudo-IR:
1212 /// ```
1213 /// for i = 0 to 9 {
1214 ///   %t = vector.extractelement %vec[i] : vector<9xf32>
1215 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
1216 /// }
1217 /// ```
1218 template <typename OpTy>
1219 struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
1220   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
1221 
1222   LogicalResult matchAndRewrite(OpTy xferOp,
1223                                 PatternRewriter &rewriter) const override {
1224     // TODO: support 0-d corner case.
1225     if (xferOp.getTransferRank() == 0)
1226       return failure();
1227     auto map = xferOp.getPermutationMap();
1228     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
1229 
1230     if (!memRefType)
1231       return failure();
1232     if (xferOp.getVectorType().getRank() != 1)
1233       return failure();
1234     if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
1235       return failure(); // Handled by ConvertVectorToLLVM
1236 
1237     // Loop bounds, step, state...
1238     Location loc = xferOp.getLoc();
1239     auto vecType = xferOp.getVectorType();
1240     auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1241     auto ub =
1242         rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
1243     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
1244     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
1245 
1246     // Generate for loop.
1247     rewriter.replaceOpWithNewOp<scf::ForOp>(
1248         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
1249         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
1250           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
1251         });
1252 
1253     return success();
1254   }
1255 };
1256 
1257 } // namespace lowering_1_d
1258 } // namespace
1259 
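/// Populate `patterns` with the progressive lowering patterns: either the
/// fully unrolled variants or the prepare + unpack-one-dimension variants,
/// depending on `options.unroll`, plus the scalar 1-D lowering when the
/// target rank is 1.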
1260 void mlir::populateVectorToSCFConversionPatterns(
1261     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
1262   if (options.unroll) {
1263     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1264                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
1265         patterns.getContext(), options);
1266   } else {
1267     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1268                  lowering_n_d::PrepareTransferWriteConversion,
1269                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1270                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1271         patterns.getContext(), options);
1272   }
1273 
1274   if (options.targetRank == 1) {
1275     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1276                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1277         patterns.getContext(), options);
1278   }
1279 }
1280 
1281 namespace {
1282 
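/// Pass that mirrors its pass options into VectorTransferToSCFOptions, lowers
/// transfer permutation maps first if requested, and then greedily applies
/// the VectorToSCF conversion patterns.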
1283 struct ConvertVectorToSCFPass
1284     : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
1285   ConvertVectorToSCFPass() = default;
1286   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1287     this->fullUnroll = options.unroll;
1288     this->targetRank = options.targetRank;
1289     this->lowerPermutationMaps = options.lowerPermutationMaps;
1290     this->lowerTensors = options.lowerTensors;
1291   }
1292 
1293   void runOnOperation() override {
1294     VectorTransferToSCFOptions options;
1295     options.unroll = fullUnroll;
1296     options.targetRank = targetRank;
1297     options.lowerPermutationMaps = lowerPermutationMaps;
1298     options.lowerTensors = lowerTensors;
1299 
1300     // Lower permutation maps first.
1301     if (lowerPermutationMaps) {
1302       RewritePatternSet lowerTransferPatterns(&getContext());
1303       mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1304           lowerTransferPatterns);
1305       (void)applyPatternsAndFoldGreedily(getOperation(),
1306                                          std::move(lowerTransferPatterns));
1307     }
1308 
1309     RewritePatternSet patterns(&getContext());
1310     populateVectorToSCFConversionPatterns(patterns, options);
1311     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
1312   }
1313 };
1314 
1315 } // namespace
1316 
1317 std::unique_ptr<Pass>
1318 mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
1319   return std::make_unique<ConvertVectorToSCFPass>(options);
1320 }
1321