10f241638SMatthias Springer //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
24ead2cf7SAlex Zinenko //
34ead2cf7SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44ead2cf7SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
54ead2cf7SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64ead2cf7SAlex Zinenko //
74ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===//
84ead2cf7SAlex Zinenko //
90f241638SMatthias Springer // This file implements lowering of vector transfer operations to SCF.
104ead2cf7SAlex Zinenko //
114ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===//
124ead2cf7SAlex Zinenko 
134ead2cf7SAlex Zinenko #include <type_traits>
144ead2cf7SAlex Zinenko 
154ead2cf7SAlex Zinenko #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
165f9e0466SNicolas Vasilache 
175f9e0466SNicolas Vasilache #include "../PassDetail.h"
18*6825bfe2SNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h"
19*6825bfe2SNicolas Vasilache #include "mlir/Dialect/Affine/Utils.h"
20*6825bfe2SNicolas Vasilache #include "mlir/Dialect/SCF/SCF.h"
214ead2cf7SAlex Zinenko #include "mlir/Dialect/Vector/VectorOps.h"
227c3c5b11SNicolas Vasilache #include "mlir/Dialect/Vector/VectorUtils.h"
234ead2cf7SAlex Zinenko #include "mlir/IR/Builders.h"
24*6825bfe2SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
255f9e0466SNicolas Vasilache #include "mlir/Pass/Pass.h"
26b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
275f9e0466SNicolas Vasilache #include "mlir/Transforms/Passes.h"
284ead2cf7SAlex Zinenko 
294ead2cf7SAlex Zinenko using namespace mlir;
304ead2cf7SAlex Zinenko using vector::TransferReadOp;
314ead2cf7SAlex Zinenko using vector::TransferWriteOp;
324ead2cf7SAlex Zinenko 
33350dadaaSBenjamin Kramer namespace {
340f241638SMatthias Springer 
/// Name of the attribute that is attached to transfer ops while they are being
/// lowered progressively (see maybeApplyPassLabel / checkPrepareXferOp).
static constexpr char kPassLabel[] = "__vector_to_scf_lowering__";
370f241638SMatthias Springer 
382ca887deSMatthias Springer /// Patterns that inherit from this struct have access to
392ca887deSMatthias Springer /// VectorTransferToSCFOptions.
402ca887deSMatthias Springer template <typename OpTy>
412ca887deSMatthias Springer struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
422ca887deSMatthias Springer   explicit VectorToSCFPattern(MLIRContext *context,
432ca887deSMatthias Springer                               VectorTransferToSCFOptions opt)
442ca887deSMatthias Springer       : OpRewritePattern<OpTy>(context), options(opt) {}
452ca887deSMatthias Springer 
462ca887deSMatthias Springer   VectorTransferToSCFOptions options;
472ca887deSMatthias Springer };
480f241638SMatthias Springer 
490f241638SMatthias Springer /// Given a vector transfer op, calculate which dimension of the `source`
500f241638SMatthias Springer /// memref should be unpacked in the next application of TransferOpConversion.
510f241638SMatthias Springer /// A return value of None indicates a broadcast.
520f241638SMatthias Springer template <typename OpTy>
530f241638SMatthias Springer static Optional<int64_t> unpackedDim(OpTy xferOp) {
540f241638SMatthias Springer   auto map = xferOp.permutation_map();
550f241638SMatthias Springer   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
560f241638SMatthias Springer     return expr.getPosition();
577c3c5b11SNicolas Vasilache   }
580f241638SMatthias Springer   assert(xferOp.isBroadcastDim(0) &&
590f241638SMatthias Springer          "Expected AffineDimExpr or AffineConstantExpr");
600f241638SMatthias Springer   return None;
610f241638SMatthias Springer }
620f241638SMatthias Springer 
630f241638SMatthias Springer /// Compute the permutation map for the new (N-1)-D vector transfer op. This
640f241638SMatthias Springer /// map is identical to the current permutation map, but the first result is
650f241638SMatthias Springer /// omitted.
660f241638SMatthias Springer template <typename OpTy>
67*6825bfe2SNicolas Vasilache static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
680f241638SMatthias Springer   auto map = xferOp.permutation_map();
690f241638SMatthias Springer   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
70*6825bfe2SNicolas Vasilache                         b.getContext());
710f241638SMatthias Springer }
720f241638SMatthias Springer 
730f241638SMatthias Springer /// Calculate the indices for the new vector transfer op.
740f241638SMatthias Springer ///
750f241638SMatthias Springer /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
760f241638SMatthias Springer ///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32>
770f241638SMatthias Springer ///                                 ^^^^^^
780f241638SMatthias Springer ///              `iv` is the iteration variable of the (new) surrounding loop.
790f241638SMatthias Springer template <typename OpTy>
80*6825bfe2SNicolas Vasilache static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
810f241638SMatthias Springer                            SmallVector<Value, 8> &indices) {
820f241638SMatthias Springer   typename OpTy::Adaptor adaptor(xferOp);
830f241638SMatthias Springer   // Corresponding memref dim of the vector dim that is unpacked.
840f241638SMatthias Springer   auto dim = unpackedDim(xferOp);
850f241638SMatthias Springer   auto prevIndices = adaptor.indices();
860f241638SMatthias Springer   indices.append(prevIndices.begin(), prevIndices.end());
870f241638SMatthias Springer 
88*6825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
890f241638SMatthias Springer   bool isBroadcast = !dim.hasValue();
900f241638SMatthias Springer   if (!isBroadcast) {
91*6825bfe2SNicolas Vasilache     AffineExpr d0, d1;
92*6825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
93*6825bfe2SNicolas Vasilache     Value offset = adaptor.indices()[dim.getValue()];
94*6825bfe2SNicolas Vasilache     indices[dim.getValue()] =
95*6825bfe2SNicolas Vasilache         makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
960f241638SMatthias Springer   }
970f241638SMatthias Springer }
980f241638SMatthias Springer 
99*6825bfe2SNicolas Vasilache static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
1000f241638SMatthias Springer                             Value value) {
1010f241638SMatthias Springer   if (hasRetVal) {
102*6825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc, value);
1030f241638SMatthias Springer   } else {
104*6825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc);
1050f241638SMatthias Springer   }
1060f241638SMatthias Springer }
1070f241638SMatthias Springer 
1080f241638SMatthias Springer /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
1090f241638SMatthias Springer /// is set to true. No such check is generated under following circumstances:
1100f241638SMatthias Springer /// * xferOp does not have a mask.
1110f241638SMatthias Springer /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
1120f241638SMatthias Springer ///   computed and attached to the new transfer op in the pattern.)
1130f241638SMatthias Springer /// * The to-be-unpacked dim of xferOp is a broadcast.
1140f241638SMatthias Springer template <typename OpTy>
115*6825bfe2SNicolas Vasilache static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
1160f241638SMatthias Springer   if (!xferOp.mask())
1170f241638SMatthias Springer     return Value();
1180f241638SMatthias Springer   if (xferOp.getMaskType().getRank() != 1)
1190f241638SMatthias Springer     return Value();
1200f241638SMatthias Springer   if (xferOp.isBroadcastDim(0))
1210f241638SMatthias Springer     return Value();
1220f241638SMatthias Springer 
123*6825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
124*6825bfe2SNicolas Vasilache   Value ivI32 =
125*6825bfe2SNicolas Vasilache       b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
126*6825bfe2SNicolas Vasilache   return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
1270f241638SMatthias Springer }
1280f241638SMatthias Springer 
1290f241638SMatthias Springer /// Helper function TransferOpConversion and TransferOp1dConversion.
1300f241638SMatthias Springer /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
1310f241638SMatthias Springer /// specified dimension `dim` with the loop iteration variable `iv`.
1320f241638SMatthias Springer /// E.g., when unpacking dimension 0 from:
1330f241638SMatthias Springer /// ```
1340f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b] %cst
1350f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?xf32>
1360f241638SMatthias Springer /// ```
1370f241638SMatthias Springer /// An if check similar to this will be generated inside the loop:
1380f241638SMatthias Springer /// ```
1390f241638SMatthias Springer /// %d = memref.dim %A, %c0 : memref<?x?xf32>
1400f241638SMatthias Springer /// if (%a + iv < %d) {
1410f241638SMatthias Springer ///   (in-bounds case)
1420f241638SMatthias Springer /// } else {
1430f241638SMatthias Springer ///   (out-of-bounds case)
1440f241638SMatthias Springer /// }
1450f241638SMatthias Springer /// ```
1460f241638SMatthias Springer ///
1470f241638SMatthias Springer /// If the transfer is 1D and has a mask, this function generates a more complex
1480f241638SMatthias Springer /// check also accounts for potentially masked out elements.
1490f241638SMatthias Springer ///
1500f241638SMatthias Springer /// This function variant returns the value returned by `inBoundsCase` or
1510f241638SMatthias Springer /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
1520f241638SMatthias Springer /// `resultTypes`.
1530f241638SMatthias Springer template <typename OpTy>
1540f241638SMatthias Springer static Value generateInBoundsCheck(
155*6825bfe2SNicolas Vasilache     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
1560f241638SMatthias Springer     TypeRange resultTypes,
1570f241638SMatthias Springer     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
1580f241638SMatthias Springer     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
1590f241638SMatthias Springer   bool hasRetVal = !resultTypes.empty();
1600f241638SMatthias Springer   Value cond; // Condition to be built...
1610f241638SMatthias Springer 
1620f241638SMatthias Springer   // Condition check 1: Access in-bounds?
1630f241638SMatthias Springer   bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
164*6825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
165*6825bfe2SNicolas Vasilache   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
1660f241638SMatthias Springer   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
167*6825bfe2SNicolas Vasilache     Value memrefDim = lb.create<memref::DimOp>(xferOp.source(), *dim);
168*6825bfe2SNicolas Vasilache     AffineExpr d0, d1;
169*6825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
170*6825bfe2SNicolas Vasilache     Value base = xferOp.indices()[dim.getValue()];
171*6825bfe2SNicolas Vasilache     Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
172*6825bfe2SNicolas Vasilache     cond = lb.create<CmpIOp>(CmpIPredicate::sgt, memrefDim, memrefIdx);
1730f241638SMatthias Springer   }
1740f241638SMatthias Springer 
1750f241638SMatthias Springer   // Condition check 2: Masked in?
176*6825bfe2SNicolas Vasilache   if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
177*6825bfe2SNicolas Vasilache     if (cond)
178*6825bfe2SNicolas Vasilache       cond = lb.create<AndOp>(cond, maskCond);
179*6825bfe2SNicolas Vasilache     else
1800f241638SMatthias Springer       cond = maskCond;
1810f241638SMatthias Springer   }
1820f241638SMatthias Springer 
1830f241638SMatthias Springer   // If the condition is non-empty, generate an SCF::IfOp.
1840f241638SMatthias Springer   if (cond) {
185*6825bfe2SNicolas Vasilache     auto check = lb.create<scf::IfOp>(
186*6825bfe2SNicolas Vasilache         resultTypes, cond,
1870f241638SMatthias Springer         /*thenBuilder=*/
188*6825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
189*6825bfe2SNicolas Vasilache           maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
190cadb7ccfSAlex Zinenko         },
1910f241638SMatthias Springer         /*elseBuilder=*/
192*6825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
1930f241638SMatthias Springer           if (outOfBoundsCase) {
194*6825bfe2SNicolas Vasilache             maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
1957c3c5b11SNicolas Vasilache           } else {
196*6825bfe2SNicolas Vasilache             b.create<scf::YieldOp>(loc);
1977c3c5b11SNicolas Vasilache           }
1987c3c5b11SNicolas Vasilache         });
1997c3c5b11SNicolas Vasilache 
2000f241638SMatthias Springer     return hasRetVal ? check.getResult(0) : Value();
2014ead2cf7SAlex Zinenko   }
2024ead2cf7SAlex Zinenko 
2030f241638SMatthias Springer   // Condition is empty, no need for an SCF::IfOp.
204*6825bfe2SNicolas Vasilache   return inBoundsCase(b, loc);
2050f241638SMatthias Springer }
2060f241638SMatthias Springer 
2070f241638SMatthias Springer /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
2080f241638SMatthias Springer /// a return value. Consequently, this function does not have a return value.
2090f241638SMatthias Springer template <typename OpTy>
2100f241638SMatthias Springer static void generateInBoundsCheck(
211*6825bfe2SNicolas Vasilache     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
2120f241638SMatthias Springer     function_ref<void(OpBuilder &, Location)> inBoundsCase,
2130f241638SMatthias Springer     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
2140f241638SMatthias Springer   generateInBoundsCheck(
215*6825bfe2SNicolas Vasilache       b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
2160f241638SMatthias Springer       /*inBoundsCase=*/
217*6825bfe2SNicolas Vasilache       [&](OpBuilder &b, Location loc) {
218*6825bfe2SNicolas Vasilache         inBoundsCase(b, loc);
2190f241638SMatthias Springer         return Value();
2200f241638SMatthias Springer       },
2210f241638SMatthias Springer       /*outOfBoundsCase=*/
222*6825bfe2SNicolas Vasilache       [&](OpBuilder &b, Location loc) {
2230f241638SMatthias Springer         if (outOfBoundsCase)
224*6825bfe2SNicolas Vasilache           outOfBoundsCase(b, loc);
2250f241638SMatthias Springer         return Value();
2260f241638SMatthias Springer       });
2270f241638SMatthias Springer }
2280f241638SMatthias Springer 
2290f241638SMatthias Springer /// Given an ArrayAttr, return a copy where the first element is dropped.
230*6825bfe2SNicolas Vasilache static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
2310f241638SMatthias Springer   if (!attr)
2320f241638SMatthias Springer     return attr;
233*6825bfe2SNicolas Vasilache   return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
2340f241638SMatthias Springer }
2350f241638SMatthias Springer 
2360f241638SMatthias Springer /// Add the pass label to a vector transfer op if its rank is not the target
2370f241638SMatthias Springer /// rank.
2380f241638SMatthias Springer template <typename OpTy>
239*6825bfe2SNicolas Vasilache static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
2402ca887deSMatthias Springer                                 unsigned targetRank) {
2412ca887deSMatthias Springer   if (newXferOp.getVectorType().getRank() > targetRank)
242*6825bfe2SNicolas Vasilache     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
2430f241638SMatthias Springer }
2440f241638SMatthias Springer 
245a088bed4SMatthias Springer namespace lowering_n_d {
246a088bed4SMatthias Springer 
247a088bed4SMatthias Springer /// Helper data structure for data and mask buffers.
248a088bed4SMatthias Springer struct BufferAllocs {
249a088bed4SMatthias Springer   Value dataBuffer;
250a088bed4SMatthias Springer   Value maskBuffer;
251a088bed4SMatthias Springer };
252a088bed4SMatthias Springer 
253a088bed4SMatthias Springer /// Allocate temporary buffers for data (vector) and mask (if present).
254a088bed4SMatthias Springer /// TODO: Parallelism and threadlocal considerations.
255a088bed4SMatthias Springer template <typename OpTy>
256*6825bfe2SNicolas Vasilache static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
257*6825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
258a088bed4SMatthias Springer   OpBuilder::InsertionGuard guard(b);
259a088bed4SMatthias Springer   Operation *scope =
260a088bed4SMatthias Springer       xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
261a088bed4SMatthias Springer   assert(scope && "Expected op to be inside automatic allocation scope");
262a088bed4SMatthias Springer   b.setInsertionPointToStart(&scope->getRegion(0).front());
263a088bed4SMatthias Springer 
264a088bed4SMatthias Springer   BufferAllocs result;
265a088bed4SMatthias Springer   auto bufferType = MemRefType::get({}, xferOp.getVectorType());
266*6825bfe2SNicolas Vasilache   result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
267a088bed4SMatthias Springer 
268a088bed4SMatthias Springer   if (xferOp.mask()) {
269a088bed4SMatthias Springer     auto maskType = MemRefType::get({}, xferOp.mask().getType());
270*6825bfe2SNicolas Vasilache     auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
271fb7ec1f1SMatthias Springer     b.setInsertionPoint(xferOp);
272*6825bfe2SNicolas Vasilache     b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
273*6825bfe2SNicolas Vasilache     result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
274a088bed4SMatthias Springer   }
275a088bed4SMatthias Springer 
276a088bed4SMatthias Springer   return result;
277a088bed4SMatthias Springer }
278a088bed4SMatthias Springer 
279a088bed4SMatthias Springer /// Given a MemRefType with VectorType element type, unpack one dimension from
280a088bed4SMatthias Springer /// the VectorType into the MemRefType.
281a088bed4SMatthias Springer ///
282a088bed4SMatthias Springer /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
283a088bed4SMatthias Springer static MemRefType unpackOneDim(MemRefType type) {
284a088bed4SMatthias Springer   auto vectorType = type.getElementType().dyn_cast<VectorType>();
285a088bed4SMatthias Springer   auto memrefShape = type.getShape();
286a088bed4SMatthias Springer   SmallVector<int64_t, 8> newMemrefShape;
287a088bed4SMatthias Springer   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
288a088bed4SMatthias Springer   newMemrefShape.push_back(vectorType.getDimSize(0));
289a088bed4SMatthias Springer   return MemRefType::get(newMemrefShape,
290a088bed4SMatthias Springer                          VectorType::get(vectorType.getShape().drop_front(),
291a088bed4SMatthias Springer                                          vectorType.getElementType()));
292a088bed4SMatthias Springer }
293a088bed4SMatthias Springer 
2940f241638SMatthias Springer /// Given a transfer op, find the memref from which the mask is loaded. This
2950f241638SMatthias Springer /// is similar to Strategy<TransferWriteOp>::getBuffer.
2960f241638SMatthias Springer template <typename OpTy>
2970f241638SMatthias Springer static Value getMaskBuffer(OpTy xferOp) {
2980f241638SMatthias Springer   assert(xferOp.mask() && "Expected that transfer op has mask");
2990f241638SMatthias Springer   auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
3000f241638SMatthias Springer   assert(loadOp && "Expected transfer op mask produced by LoadOp");
3010f241638SMatthias Springer   return loadOp.getMemRef();
3020f241638SMatthias Springer }
3030f241638SMatthias Springer 
/// Codegen strategy, selected by the kind of transfer op. Only the explicit
/// specializations are defined.
template <typename OpTy>
struct Strategy;
3070f241638SMatthias Springer 
3080f241638SMatthias Springer /// Code strategy for vector TransferReadOp.
3094ead2cf7SAlex Zinenko template <>
3100f241638SMatthias Springer struct Strategy<TransferReadOp> {
3110f241638SMatthias Springer   /// Find the StoreOp that is used for writing the current TransferReadOp's
3120f241638SMatthias Springer   /// result to the temporary buffer allocation.
3130f241638SMatthias Springer   static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
3140f241638SMatthias Springer     assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
3150f241638SMatthias Springer     auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
3160f241638SMatthias Springer     assert(storeOp && "Expected TransferReadOp result used by StoreOp");
3170f241638SMatthias Springer     return storeOp;
3187c3c5b11SNicolas Vasilache   }
3194ead2cf7SAlex Zinenko 
3200f241638SMatthias Springer   /// Find the temporary buffer allocation. All labeled TransferReadOps are
3210f241638SMatthias Springer   /// used like this, where %buf is either the buffer allocation or a type cast
3220f241638SMatthias Springer   /// of the buffer allocation:
3230f241638SMatthias Springer   /// ```
3240f241638SMatthias Springer   /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
3250f241638SMatthias Springer   /// memref.store %vec, %buf[...] ...
3260f241638SMatthias Springer   /// ```
3270f241638SMatthias Springer   static Value getBuffer(TransferReadOp xferOp) {
3280f241638SMatthias Springer     return getStoreOp(xferOp).getMemRef();
3291870e787SNicolas Vasilache   }
3300f241638SMatthias Springer 
3310f241638SMatthias Springer   /// Retrieve the indices of the current StoreOp that stores into the buffer.
3320f241638SMatthias Springer   static void getBufferIndices(TransferReadOp xferOp,
3330f241638SMatthias Springer                                SmallVector<Value, 8> &indices) {
3340f241638SMatthias Springer     auto storeOp = getStoreOp(xferOp);
3350f241638SMatthias Springer     auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
3360f241638SMatthias Springer     indices.append(prevIndices.begin(), prevIndices.end());
3370f241638SMatthias Springer   }
3380f241638SMatthias Springer 
3390f241638SMatthias Springer   /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
3400f241638SMatthias Springer   /// accesses on the to-be-unpacked dimension.
3410f241638SMatthias Springer   ///
3420f241638SMatthias Springer   /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
3430f241638SMatthias Springer   ///    variable `iv`.
3440f241638SMatthias Springer   /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
3450f241638SMatthias Springer   ///
3460f241638SMatthias Springer   /// E.g.:
3470f241638SMatthias Springer   /// ```
3480f241638SMatthias Springer   /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
3490f241638SMatthias Springer   ///     : memref<?x?x?xf32>, vector<4x3xf32>
3500f241638SMatthias Springer   /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
3510f241638SMatthias Springer   /// ```
3520f241638SMatthias Springer   /// Is rewritten to:
3530f241638SMatthias Springer   /// ```
3540f241638SMatthias Springer   /// %casted = vector.type_cast %buf
3550f241638SMatthias Springer   ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
3560f241638SMatthias Springer   /// for %j = 0 to 4 {
3570f241638SMatthias Springer   ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
3580f241638SMatthias Springer   ///       : memref<?x?x?xf32>, vector<3xf32>
3590f241638SMatthias Springer   ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
3600f241638SMatthias Springer   /// }
3610f241638SMatthias Springer   /// ```
3620f241638SMatthias Springer   ///
3630f241638SMatthias Springer   /// Note: The loop and type cast are generated in TransferOpConversion.
3640f241638SMatthias Springer   ///       The original TransferReadOp and store op are deleted in `cleanup`.
3650f241638SMatthias Springer   /// Note: The `mask` operand is set in TransferOpConversion.
366*6825bfe2SNicolas Vasilache   static TransferReadOp rewriteOp(OpBuilder &b,
3672ca887deSMatthias Springer                                   VectorTransferToSCFOptions options,
3682ca887deSMatthias Springer                                   TransferReadOp xferOp, Value buffer,
3692ca887deSMatthias Springer                                   Value iv) {
3700f241638SMatthias Springer     SmallVector<Value, 8> storeIndices;
3710f241638SMatthias Springer     getBufferIndices(xferOp, storeIndices);
3720f241638SMatthias Springer     storeIndices.push_back(iv);
3730f241638SMatthias Springer 
3740f241638SMatthias Springer     SmallVector<Value, 8> xferIndices;
375*6825bfe2SNicolas Vasilache     getXferIndices(b, xferOp, iv, xferIndices);
3760f241638SMatthias Springer 
377*6825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
3780f241638SMatthias Springer     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
3790f241638SMatthias Springer     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
380*6825bfe2SNicolas Vasilache     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
381*6825bfe2SNicolas Vasilache     auto newXferOp = b.create<vector::TransferReadOp>(
382*6825bfe2SNicolas Vasilache         loc, vecType, xferOp.source(), xferIndices,
383*6825bfe2SNicolas Vasilache         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
384*6825bfe2SNicolas Vasilache         Value(), inBoundsAttr);
3850f241638SMatthias Springer 
386*6825bfe2SNicolas Vasilache     maybeApplyPassLabel(b, newXferOp, options.targetRank);
3870f241638SMatthias Springer 
388*6825bfe2SNicolas Vasilache     b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
389*6825bfe2SNicolas Vasilache     return newXferOp;
3900f241638SMatthias Springer   }
3910f241638SMatthias Springer 
3920f241638SMatthias Springer   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
3930f241638SMatthias Springer   /// padding value to the temporary buffer.
394*6825bfe2SNicolas Vasilache   static void handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
395*6825bfe2SNicolas Vasilache                                    Value buffer, Value iv) {
3960f241638SMatthias Springer     SmallVector<Value, 8> storeIndices;
3970f241638SMatthias Springer     getBufferIndices(xferOp, storeIndices);
3980f241638SMatthias Springer     storeIndices.push_back(iv);
3990f241638SMatthias Springer 
400*6825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
4010f241638SMatthias Springer     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
4020f241638SMatthias Springer     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
403*6825bfe2SNicolas Vasilache     auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
404*6825bfe2SNicolas Vasilache     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
4050f241638SMatthias Springer   }
4060f241638SMatthias Springer 
4070f241638SMatthias Springer   /// Cleanup after rewriting the op.
4080f241638SMatthias Springer   static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp) {
4090f241638SMatthias Springer     rewriter.eraseOp(getStoreOp(xferOp));
4100f241638SMatthias Springer     rewriter.eraseOp(xferOp);
4110f241638SMatthias Springer   }
4124ead2cf7SAlex Zinenko };
4137c3c5b11SNicolas Vasilache 
4140f241638SMatthias Springer /// Codegen strategy for vector TransferWriteOp.
4150f241638SMatthias Springer template <>
4160f241638SMatthias Springer struct Strategy<TransferWriteOp> {
4170f241638SMatthias Springer   /// Find the temporary buffer allocation. All labeled TransferWriteOps are
4180f241638SMatthias Springer   /// used like this, where %buf is either the buffer allocation or a type cast
4190f241638SMatthias Springer   /// of the buffer allocation:
4200f241638SMatthias Springer   /// ```
4210f241638SMatthias Springer   /// %vec = memref.load %buf[...] ...
4220f241638SMatthias Springer   /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
4230f241638SMatthias Springer   /// ```
4240f241638SMatthias Springer   static Value getBuffer(TransferWriteOp xferOp) {
4250f241638SMatthias Springer     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
4260f241638SMatthias Springer     assert(loadOp && "Expected transfer op vector produced by LoadOp");
4270f241638SMatthias Springer     return loadOp.getMemRef();
4287c3c5b11SNicolas Vasilache   }
4294ead2cf7SAlex Zinenko 
4300f241638SMatthias Springer   /// Retrieve the indices of the current LoadOp that loads from the buffer.
4310f241638SMatthias Springer   static void getBufferIndices(TransferWriteOp xferOp,
4320f241638SMatthias Springer                                SmallVector<Value, 8> &indices) {
4330f241638SMatthias Springer     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
4340f241638SMatthias Springer     auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
4350f241638SMatthias Springer     indices.append(prevIndices.begin(), prevIndices.end());
4360f241638SMatthias Springer   }
4370f241638SMatthias Springer 
4380f241638SMatthias Springer   /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
4390f241638SMatthias Springer   /// accesses on the to-be-unpacked dimension.
4400f241638SMatthias Springer   ///
4410f241638SMatthias Springer   /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
4420f241638SMatthias Springer   ///    using the loop iteration variable `iv`.
4430f241638SMatthias Springer   /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
4440f241638SMatthias Springer   ///    to memory.
4450f241638SMatthias Springer   ///
4460f241638SMatthias Springer   /// Note: For more details, see comments on Strategy<TransferReadOp>.
447*6825bfe2SNicolas Vasilache   static TransferWriteOp rewriteOp(OpBuilder &b,
4482ca887deSMatthias Springer                                    VectorTransferToSCFOptions options,
4492ca887deSMatthias Springer                                    TransferWriteOp xferOp, Value buffer,
4502ca887deSMatthias Springer                                    Value iv) {
4510f241638SMatthias Springer     SmallVector<Value, 8> loadIndices;
4520f241638SMatthias Springer     getBufferIndices(xferOp, loadIndices);
4530f241638SMatthias Springer     loadIndices.push_back(iv);
4540f241638SMatthias Springer 
4550f241638SMatthias Springer     SmallVector<Value, 8> xferIndices;
456*6825bfe2SNicolas Vasilache     getXferIndices(b, xferOp, iv, xferIndices);
4570f241638SMatthias Springer 
458*6825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
459*6825bfe2SNicolas Vasilache     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
460*6825bfe2SNicolas Vasilache     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
461*6825bfe2SNicolas Vasilache     auto newXferOp = b.create<vector::TransferWriteOp>(
462*6825bfe2SNicolas Vasilache         loc, Type(), vec, xferOp.source(), xferIndices,
463*6825bfe2SNicolas Vasilache         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
4640f241638SMatthias Springer         inBoundsAttr);
4650f241638SMatthias Springer 
466*6825bfe2SNicolas Vasilache     maybeApplyPassLabel(b, newXferOp, options.targetRank);
4670f241638SMatthias Springer 
468*6825bfe2SNicolas Vasilache     return newXferOp;
4690f241638SMatthias Springer   }
4700f241638SMatthias Springer 
4710f241638SMatthias Springer   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
472*6825bfe2SNicolas Vasilache   static void handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
4730f241638SMatthias Springer                                    Value buffer, Value iv) {}
4740f241638SMatthias Springer 
4750f241638SMatthias Springer   /// Cleanup after rewriting the op.
4760f241638SMatthias Springer   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
4770f241638SMatthias Springer     rewriter.eraseOp(xferOp);
4780f241638SMatthias Springer   }
4790f241638SMatthias Springer };
4800f241638SMatthias Springer 
4810f241638SMatthias Springer template <typename OpTy>
482fb7ec1f1SMatthias Springer LogicalResult checkPrepareXferOp(OpTy xferOp,
483fb7ec1f1SMatthias Springer                                  VectorTransferToSCFOptions options) {
4840f241638SMatthias Springer   if (xferOp->hasAttr(kPassLabel))
4850f241638SMatthias Springer     return failure();
486fb7ec1f1SMatthias Springer   if (xferOp.getVectorType().getRank() <= options.targetRank)
4870f241638SMatthias Springer     return failure();
4880f241638SMatthias Springer   return success();
4890f241638SMatthias Springer }
4900f241638SMatthias Springer 
4910f241638SMatthias Springer /// Prepare a TransferReadOp for progressive lowering.
4920f241638SMatthias Springer ///
4930f241638SMatthias Springer /// 1. Allocate a temporary buffer.
4940f241638SMatthias Springer /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
4950f241638SMatthias Springer /// 3. Store the result of the TransferReadOp into the temporary buffer.
4960f241638SMatthias Springer /// 4. Load the result from the temporary buffer and replace all uses of the
4970f241638SMatthias Springer ///    original TransferReadOp with this load.
4980f241638SMatthias Springer ///
4990f241638SMatthias Springer /// E.g.:
5000f241638SMatthias Springer /// ```
5010f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
5020f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
5030f241638SMatthias Springer /// ```
5040f241638SMatthias Springer /// is rewritten to:
5050f241638SMatthias Springer /// ```
5060f241638SMatthias Springer /// %0 = memref.alloca() : memref<vector<5x4xf32>>
5070f241638SMatthias Springer /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
5080f241638SMatthias Springer ///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
5090f241638SMatthias Springer /// memref.store %1, %0[] : memref<vector<5x4xf32>>
5100f241638SMatthias Springer /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
5110f241638SMatthias Springer /// ```
5120f241638SMatthias Springer ///
5130f241638SMatthias Springer /// Note: A second temporary buffer may be allocated for the `mask` operand.
5142ca887deSMatthias Springer struct PrepareTransferReadConversion
5152ca887deSMatthias Springer     : public VectorToSCFPattern<TransferReadOp> {
5162ca887deSMatthias Springer   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
5170f241638SMatthias Springer 
5180f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferReadOp xferOp,
5190f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
520fb7ec1f1SMatthias Springer     if (checkPrepareXferOp(xferOp, options).failed())
5210f241638SMatthias Springer       return failure();
5220f241638SMatthias Springer 
523*6825bfe2SNicolas Vasilache     auto buffers = allocBuffers(rewriter, xferOp);
5240f241638SMatthias Springer     auto *newXfer = rewriter.clone(*xferOp.getOperation());
5250f241638SMatthias Springer     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
5260f241638SMatthias Springer     if (xferOp.mask()) {
5270f241638SMatthias Springer       dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
5280f241638SMatthias Springer           buffers.maskBuffer);
5290f241638SMatthias Springer     }
5300f241638SMatthias Springer 
531*6825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
532*6825bfe2SNicolas Vasilache     rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
533*6825bfe2SNicolas Vasilache                                      buffers.dataBuffer);
5340f241638SMatthias Springer     rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
5354ead2cf7SAlex Zinenko 
5364ead2cf7SAlex Zinenko     return success();
5374ead2cf7SAlex Zinenko   }
5380f241638SMatthias Springer };
5390f241638SMatthias Springer 
5400f241638SMatthias Springer /// Prepare a TransferWriteOp for progressive lowering.
5410f241638SMatthias Springer ///
5420f241638SMatthias Springer /// 1. Allocate a temporary buffer.
5430f241638SMatthias Springer /// 2. Store the vector into the buffer.
5440f241638SMatthias Springer /// 3. Load the vector from the buffer again.
5450f241638SMatthias Springer /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
5460f241638SMatthias Springer ///    marking it eligible for progressive lowering via TransferOpConversion.
5470f241638SMatthias Springer ///
5480f241638SMatthias Springer /// E.g.:
5490f241638SMatthias Springer /// ```
5500f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c]
5510f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
5520f241638SMatthias Springer /// ```
5530f241638SMatthias Springer /// is rewritten to:
5540f241638SMatthias Springer /// ```
5550f241638SMatthias Springer /// %0 = memref.alloca() : memref<vector<5x4xf32>>
5560f241638SMatthias Springer /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
5570f241638SMatthias Springer /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
5580f241638SMatthias Springer /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
5590f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
5600f241638SMatthias Springer /// ```
5610f241638SMatthias Springer ///
5620f241638SMatthias Springer /// Note: A second temporary buffer may be allocated for the `mask` operand.
5630f241638SMatthias Springer struct PrepareTransferWriteConversion
5642ca887deSMatthias Springer     : public VectorToSCFPattern<TransferWriteOp> {
5652ca887deSMatthias Springer   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
5660f241638SMatthias Springer 
5670f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
5680f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
569fb7ec1f1SMatthias Springer     if (checkPrepareXferOp(xferOp, options).failed())
5700f241638SMatthias Springer       return failure();
5710f241638SMatthias Springer 
572*6825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
573*6825bfe2SNicolas Vasilache     auto buffers = allocBuffers(rewriter, xferOp);
574*6825bfe2SNicolas Vasilache     rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
575*6825bfe2SNicolas Vasilache     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
5760f241638SMatthias Springer     rewriter.updateRootInPlace(xferOp, [&]() {
5770f241638SMatthias Springer       xferOp.vectorMutable().assign(loadedVec);
5780f241638SMatthias Springer       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
5790f241638SMatthias Springer     });
5800f241638SMatthias Springer 
5810f241638SMatthias Springer     if (xferOp.mask()) {
5820f241638SMatthias Springer       rewriter.updateRootInPlace(
5830f241638SMatthias Springer           xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
5840f241638SMatthias Springer     }
5850f241638SMatthias Springer 
5860f241638SMatthias Springer     return success();
5870f241638SMatthias Springer   }
5880f241638SMatthias Springer };
5890f241638SMatthias Springer 
5900f241638SMatthias Springer /// Progressive lowering of vector transfer ops: Unpack one dimension.
5910f241638SMatthias Springer ///
5920f241638SMatthias Springer /// 1. Unpack one dimension from the current buffer type and cast the buffer
5930f241638SMatthias Springer ///    to that new type. E.g.:
5940f241638SMatthias Springer ///    ```
5950f241638SMatthias Springer ///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
5960f241638SMatthias Springer ///    vector.transfer_write %vec ...
5970f241638SMatthias Springer ///    ```
5980f241638SMatthias Springer ///    The following cast is generated:
5990f241638SMatthias Springer ///    ```
6000f241638SMatthias Springer ///    %casted = vector.type_cast %0
6010f241638SMatthias Springer ///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
6020f241638SMatthias Springer ///    ```
6030f241638SMatthias Springer /// 2. Generate a for loop and rewrite the transfer op according to the
6040f241638SMatthias Springer ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
6050f241638SMatthias Springer ///    out-of-bounds, generate an if-check and handle both cases separately.
6060f241638SMatthias Springer /// 3. Clean up according to the corresponding Strategy<OpTy>.
6070f241638SMatthias Springer template <typename OpTy>
6082ca887deSMatthias Springer struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
6092ca887deSMatthias Springer   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
6100f241638SMatthias Springer 
6110f241638SMatthias Springer   LogicalResult matchAndRewrite(OpTy xferOp,
6120f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
6130f241638SMatthias Springer     if (!xferOp->hasAttr(kPassLabel))
6140f241638SMatthias Springer       return failure();
6150f241638SMatthias Springer 
6160f241638SMatthias Springer     // Find and cast data buffer. How the buffer can be found depends on OpTy.
617*6825bfe2SNicolas Vasilache     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
6180f241638SMatthias Springer     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
6190f241638SMatthias Springer     auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
6200f241638SMatthias Springer     auto castedDataType = unpackOneDim(dataBufferType);
621*6825bfe2SNicolas Vasilache     auto castedDataBuffer =
622*6825bfe2SNicolas Vasilache         locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
6230f241638SMatthias Springer 
6240f241638SMatthias Springer     // If the xferOp has a mask: Find and cast mask buffer.
6250f241638SMatthias Springer     Value castedMaskBuffer;
6260f241638SMatthias Springer     if (xferOp.mask()) {
6270f241638SMatthias Springer       auto maskBuffer = getMaskBuffer(xferOp);
6280f241638SMatthias Springer       auto maskBufferType =
6290f241638SMatthias Springer           maskBuffer.getType().template dyn_cast<MemRefType>();
6300f241638SMatthias Springer       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
6310f241638SMatthias Springer         // Do not unpack a dimension of the mask, if:
6320f241638SMatthias Springer         // * To-be-unpacked transfer op dimension is a broadcast.
6330f241638SMatthias Springer         // * Mask is 1D, i.e., the mask cannot be further unpacked.
6340f241638SMatthias Springer         //   (That means that all remaining dimensions of the transfer op must
6350f241638SMatthias Springer         //   be broadcasted.)
6360f241638SMatthias Springer         castedMaskBuffer = maskBuffer;
6370f241638SMatthias Springer       } else {
6380f241638SMatthias Springer         auto castedMaskType = unpackOneDim(maskBufferType);
639*6825bfe2SNicolas Vasilache         castedMaskBuffer =
640*6825bfe2SNicolas Vasilache             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
6410f241638SMatthias Springer       }
6420f241638SMatthias Springer     }
6430f241638SMatthias Springer 
6440f241638SMatthias Springer     // Loop bounds and step.
645*6825bfe2SNicolas Vasilache     auto lb = locB.create<ConstantIndexOp>(0);
646*6825bfe2SNicolas Vasilache     auto ub = locB.create<ConstantIndexOp>(
647*6825bfe2SNicolas Vasilache         castedDataType.getDimSize(castedDataType.getRank() - 1));
648*6825bfe2SNicolas Vasilache     auto step = locB.create<ConstantIndexOp>(1);
6490f241638SMatthias Springer 
6500f241638SMatthias Springer     // Generate for loop.
651*6825bfe2SNicolas Vasilache     locB.create<scf::ForOp>(
652*6825bfe2SNicolas Vasilache         lb, ub, step, ValueRange(),
6530f241638SMatthias Springer         [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
6540f241638SMatthias Springer           generateInBoundsCheck(
655*6825bfe2SNicolas Vasilache               b, xferOp, iv, unpackedDim(xferOp),
6560f241638SMatthias Springer               /*inBoundsCase=*/
657*6825bfe2SNicolas Vasilache               [&](OpBuilder &b, Location loc) {
6580f241638SMatthias Springer                 // Create new transfer op.
6592ca887deSMatthias Springer                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
6602ca887deSMatthias Springer                     b, this->options, xferOp, castedDataBuffer, iv);
6610f241638SMatthias Springer 
6620f241638SMatthias Springer                 // If old transfer op has a mask: Set mask on new transfer op.
6630f241638SMatthias Springer                 // Special case: If the mask of the old transfer op is 1D and
6640f241638SMatthias Springer                 // the
6650f241638SMatthias Springer                 //               unpacked dim is not a broadcast, no mask is
6660f241638SMatthias Springer                 //               needed on the new transfer op.
6670f241638SMatthias Springer                 if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
6680f241638SMatthias Springer                                       xferOp.getMaskType().getRank() > 1)) {
6690f241638SMatthias Springer                   OpBuilder::InsertionGuard guard(b);
6700f241638SMatthias Springer                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
6710f241638SMatthias Springer 
6720f241638SMatthias Springer                   SmallVector<Value, 8> loadIndices;
6730f241638SMatthias Springer                   Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
6740f241638SMatthias Springer                   // In case of broadcast: Use same indices to load from memref
6750f241638SMatthias Springer                   // as before.
6760f241638SMatthias Springer                   if (!xferOp.isBroadcastDim(0))
6770f241638SMatthias Springer                     loadIndices.push_back(iv);
6780f241638SMatthias Springer 
679*6825bfe2SNicolas Vasilache                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
680*6825bfe2SNicolas Vasilache                                                        loadIndices);
6810f241638SMatthias Springer                   rewriter.updateRootInPlace(
6820f241638SMatthias Springer                       newXfer, [&]() { newXfer.maskMutable().assign(mask); });
6830f241638SMatthias Springer                 }
6840f241638SMatthias Springer               },
6850f241638SMatthias Springer               /*outOfBoundsCase=*/
6860f241638SMatthias Springer               [&](OpBuilder &b, Location /*loc*/) {
6870f241638SMatthias Springer                 Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp,
6880f241638SMatthias Springer                                                      castedDataBuffer, iv);
6890f241638SMatthias Springer               });
6900f241638SMatthias Springer           b.create<scf::YieldOp>(loc);
6910f241638SMatthias Springer         });
6920f241638SMatthias Springer 
6930f241638SMatthias Springer     Strategy<OpTy>::cleanup(rewriter, xferOp);
6940f241638SMatthias Springer     return success();
6950f241638SMatthias Springer   }
6960f241638SMatthias Springer };
6970f241638SMatthias Springer 
698a088bed4SMatthias Springer } // namespace lowering_n_d
699a088bed4SMatthias Springer 
700a088bed4SMatthias Springer namespace lowering_n_d_unrolled {
701a088bed4SMatthias Springer 
7020f241638SMatthias Springer /// If the original transfer op has a mask, compute the mask of the new transfer
7030f241638SMatthias Springer /// op (for the current iteration `i`) and assign it.
7040f241638SMatthias Springer template <typename OpTy>
705*6825bfe2SNicolas Vasilache static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
7060f241638SMatthias Springer                             int64_t i) {
7070f241638SMatthias Springer   if (!xferOp.mask())
7080f241638SMatthias Springer     return;
7090f241638SMatthias Springer 
7100f241638SMatthias Springer   if (xferOp.isBroadcastDim(0)) {
7110f241638SMatthias Springer     // To-be-unpacked dimension is a broadcast, which does not have a
7120f241638SMatthias Springer     // corresponding mask dimension. Mask attribute remains unchanged.
7130f241638SMatthias Springer     newXferOp.maskMutable().assign(xferOp.mask());
7140f241638SMatthias Springer     return;
7150f241638SMatthias Springer   }
7160f241638SMatthias Springer 
7170f241638SMatthias Springer   if (xferOp.getMaskType().getRank() > 1) {
7180f241638SMatthias Springer     // Unpack one dimension of the mask.
719*6825bfe2SNicolas Vasilache     OpBuilder::InsertionGuard guard(b);
720*6825bfe2SNicolas Vasilache     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
7210f241638SMatthias Springer 
7220f241638SMatthias Springer     llvm::SmallVector<int64_t, 1> indices({i});
723*6825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
724*6825bfe2SNicolas Vasilache     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
7250f241638SMatthias Springer     newXferOp.maskMutable().assign(newMask);
7260f241638SMatthias Springer   }
7270f241638SMatthias Springer 
7280f241638SMatthias Springer   // If we end up here: The mask of the old transfer op is 1D and the unpacked
7290f241638SMatthias Springer   // dim is not a broadcast, so no mask is needed on the new transfer op.
7300f241638SMatthias Springer   // `generateInBoundsCheck` will have evaluated the mask already.
7310f241638SMatthias Springer }
7320f241638SMatthias Springer 
7330f241638SMatthias Springer /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
7340f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
7350f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
7360f241638SMatthias Springer ///
7370f241638SMatthias Springer /// ```
7380f241638SMatthias Springer /// E.g.:
7390f241638SMatthias Springer /// ```
7400f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
7410f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<5x4xf32>
7420f241638SMatthias Springer /// ```
7430f241638SMatthias Springer /// is rewritten to IR such as (simplified):
7440f241638SMatthias Springer /// ```
7450f241638SMatthias Springer /// %v_init = splat %padding : vector<5x4xf32>
7460f241638SMatthias Springer /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
7470f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
7480f241638SMatthias Springer /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
7490f241638SMatthias Springer /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
7500f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
7510f241638SMatthias Springer /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
7520f241638SMatthias Springer /// ...
7530f241638SMatthias Springer /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
7540f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
7550f241638SMatthias Springer /// %vec = vector.insert %tmp1, %v3[4] : vector<4xf32> into vector<5x4xf32>
7560f241638SMatthias Springer /// ```
7570f241638SMatthias Springer ///
7580f241638SMatthias Springer /// Note: As an optimization, if the result of the original TransferReadOp
7590f241638SMatthias Springer /// was directly inserted into another vector, no new %v_init vector is created.
7600f241638SMatthias Springer /// Instead, the new TransferReadOp results are inserted into that vector.
7612ca887deSMatthias Springer struct UnrollTransferReadConversion
7622ca887deSMatthias Springer     : public VectorToSCFPattern<TransferReadOp> {
7632ca887deSMatthias Springer   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
7640f241638SMatthias Springer 
7650f241638SMatthias Springer   /// Return the vector into which the newly created TransferReadOp results
7660f241638SMatthias Springer   /// are inserted.
7670f241638SMatthias Springer   Value getResultVector(TransferReadOp xferOp,
7680f241638SMatthias Springer                         PatternRewriter &rewriter) const {
7690f241638SMatthias Springer     if (auto insertOp = getInsertOp(xferOp))
7700f241638SMatthias Springer       return insertOp.dest();
771*6825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
772*6825bfe2SNicolas Vasilache     return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
773*6825bfe2SNicolas Vasilache                                     xferOp.padding());
7740f241638SMatthias Springer   }
7750f241638SMatthias Springer 
7760f241638SMatthias Springer   /// If the result of the TransferReadOp has exactly one user, which is a
7770f241638SMatthias Springer   /// vector::InsertOp, return that operation.
7780f241638SMatthias Springer   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
7790f241638SMatthias Springer     if (xferOp->hasOneUse()) {
7800f241638SMatthias Springer       Operation *xferOpUser = *xferOp->getUsers().begin();
7810f241638SMatthias Springer       if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
7820f241638SMatthias Springer         return insertOp;
7830f241638SMatthias Springer     }
7840f241638SMatthias Springer 
7850f241638SMatthias Springer     return vector::InsertOp();
7860f241638SMatthias Springer   }
7870f241638SMatthias Springer 
7880f241638SMatthias Springer   /// If the result of the TransferReadOp has exactly one user, which is a
7890f241638SMatthias Springer   /// vector::InsertOp, return that operation's indices.
7900f241638SMatthias Springer   void getInsertionIndices(TransferReadOp xferOp,
7910f241638SMatthias Springer                            SmallVector<int64_t, 8> &indices) const {
7920f241638SMatthias Springer     if (auto insertOp = getInsertOp(xferOp)) {
7930f241638SMatthias Springer       llvm::for_each(insertOp.position(), [&](Attribute attr) {
7940f241638SMatthias Springer         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
7950f241638SMatthias Springer       });
7960f241638SMatthias Springer     }
7970f241638SMatthias Springer   }
7980f241638SMatthias Springer 
7990f241638SMatthias Springer   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
8000f241638SMatthias Springer   /// accesses, and broadcasts and transposes in permutation maps.
8010f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferReadOp xferOp,
8020f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
8032ca887deSMatthias Springer     if (xferOp.getVectorType().getRank() <= options.targetRank)
8040f241638SMatthias Springer       return failure();
8050f241638SMatthias Springer 
8060f241638SMatthias Springer     auto insertOp = getInsertOp(xferOp);
8070f241638SMatthias Springer     auto vec = getResultVector(xferOp, rewriter);
8080f241638SMatthias Springer     auto vecType = vec.getType().dyn_cast<VectorType>();
8090f241638SMatthias Springer     auto xferVecType = xferOp.getVectorType();
8100f241638SMatthias Springer     auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
8110f241638SMatthias Springer                                           xferVecType.getElementType());
8120f241638SMatthias Springer     int64_t dimSize = xferVecType.getShape()[0];
8130f241638SMatthias Springer 
8140f241638SMatthias Springer     // Generate fully unrolled loop of transfer ops.
815*6825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
8160f241638SMatthias Springer     for (int64_t i = 0; i < dimSize; ++i) {
817*6825bfe2SNicolas Vasilache       Value iv = rewriter.create<ConstantIndexOp>(loc, i);
8180f241638SMatthias Springer 
8190f241638SMatthias Springer       vec = generateInBoundsCheck(
820*6825bfe2SNicolas Vasilache           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
8210f241638SMatthias Springer           /*inBoundsCase=*/
8220f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
8230f241638SMatthias Springer             // Indices for the new transfer op.
8240f241638SMatthias Springer             SmallVector<Value, 8> xferIndices;
825*6825bfe2SNicolas Vasilache             getXferIndices(b, xferOp, iv, xferIndices);
8260f241638SMatthias Springer 
8270f241638SMatthias Springer             // Indices for the new vector.insert op.
8280f241638SMatthias Springer             SmallVector<int64_t, 8> insertionIndices;
8290f241638SMatthias Springer             getInsertionIndices(xferOp, insertionIndices);
8300f241638SMatthias Springer             insertionIndices.push_back(i);
8310f241638SMatthias Springer 
8320f241638SMatthias Springer             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
833*6825bfe2SNicolas Vasilache             auto newXferOp = b.create<vector::TransferReadOp>(
834*6825bfe2SNicolas Vasilache                 loc, newXferVecType, xferOp.source(), xferIndices,
835*6825bfe2SNicolas Vasilache                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
836*6825bfe2SNicolas Vasilache                 xferOp.padding(), Value(), inBoundsAttr);
8370f241638SMatthias Springer             maybeAssignMask(b, xferOp, newXferOp, i);
838*6825bfe2SNicolas Vasilache             return b.create<vector::InsertOp>(loc, newXferOp, vec,
839*6825bfe2SNicolas Vasilache                                               insertionIndices);
8400f241638SMatthias Springer           },
8410f241638SMatthias Springer           /*outOfBoundsCase=*/
8420f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
8430f241638SMatthias Springer             // Loop through original (unmodified) vector.
8440f241638SMatthias Springer             return vec;
8450f241638SMatthias Springer           });
8460f241638SMatthias Springer     }
8470f241638SMatthias Springer 
8480f241638SMatthias Springer     if (insertOp) {
8490f241638SMatthias Springer       // Rewrite single user of the old TransferReadOp, which was an InsertOp.
8500f241638SMatthias Springer       rewriter.replaceOp(insertOp, vec);
8510f241638SMatthias Springer       rewriter.eraseOp(xferOp);
8520f241638SMatthias Springer     } else {
8530f241638SMatthias Springer       rewriter.replaceOp(xferOp, vec);
8540f241638SMatthias Springer     }
8550f241638SMatthias Springer 
8560f241638SMatthias Springer     return success();
8570f241638SMatthias Springer   }
8580f241638SMatthias Springer };
8590f241638SMatthias Springer 
8600f241638SMatthias Springer /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
8610f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
8620f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
8630f241638SMatthias Springer ///
8640f241638SMatthias Springer /// ```
8650f241638SMatthias Springer /// E.g.:
8660f241638SMatthias Springer /// ```
8670f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c]
8680f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
8690f241638SMatthias Springer /// ```
8700f241638SMatthias Springer /// is rewritten to IR such as (simplified):
8710f241638SMatthias Springer /// ```
8720f241638SMatthias Springer /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
8730f241638SMatthias Springer /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
8740f241638SMatthias Springer /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
8750f241638SMatthias Springer /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
8760f241638SMatthias Springer /// ...
8770f241638SMatthias Springer /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
8780f241638SMatthias Springer /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
8790f241638SMatthias Springer /// ```
8800f241638SMatthias Springer ///
8810f241638SMatthias Springer /// Note: As an optimization, if the vector of the original TransferWriteOp
8820f241638SMatthias Springer /// was directly extracted from another vector via an ExtractOp `a`, extract
8830f241638SMatthias Springer /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
8840f241638SMatthias Springer /// doing so, `a` may become dead, and the number of ExtractOps generated during
8850f241638SMatthias Springer /// recursive application of this pattern will be minimal.
8860f241638SMatthias Springer struct UnrollTransferWriteConversion
8872ca887deSMatthias Springer     : public VectorToSCFPattern<TransferWriteOp> {
8882ca887deSMatthias Springer   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
8890f241638SMatthias Springer 
8900f241638SMatthias Springer   /// Return the vector from which newly generated ExtracOps will extract.
8910f241638SMatthias Springer   Value getDataVector(TransferWriteOp xferOp) const {
8920f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp))
8930f241638SMatthias Springer       return extractOp.vector();
8940f241638SMatthias Springer     return xferOp.vector();
8950f241638SMatthias Springer   }
8960f241638SMatthias Springer 
8970f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
8980f241638SMatthias Springer   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
8990f241638SMatthias Springer     if (auto *op = xferOp.vector().getDefiningOp())
9000f241638SMatthias Springer       return dyn_cast<vector::ExtractOp>(op);
9010f241638SMatthias Springer     return vector::ExtractOp();
9020f241638SMatthias Springer   }
9030f241638SMatthias Springer 
9040f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return its
9050f241638SMatthias Springer   /// indices.
9060f241638SMatthias Springer   void getExtractionIndices(TransferWriteOp xferOp,
9070f241638SMatthias Springer                             SmallVector<int64_t, 8> &indices) const {
9080f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp)) {
9090f241638SMatthias Springer       llvm::for_each(extractOp.position(), [&](Attribute attr) {
9100f241638SMatthias Springer         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
9110f241638SMatthias Springer       });
9120f241638SMatthias Springer     }
9130f241638SMatthias Springer   }
9140f241638SMatthias Springer 
9150f241638SMatthias Springer   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
9160f241638SMatthias Springer   /// accesses, and broadcasts and transposes in permutation maps.
  ///
  /// Fails to match when the vector rank is already <= `options.targetRank`.
  /// Otherwise one vector.extract + vector.transfer_write pair is created per
  /// element of the leading vector dimension (no loop op is emitted), and
  /// the original op is erased.
9170f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
9180f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
    // Nothing to unroll once the vector rank is at or below the target rank.
9192ca887deSMatthias Springer     if (xferOp.getVectorType().getRank() <= options.targetRank)
9200f241638SMatthias Springer       return failure();
9210f241638SMatthias Springer 
    // Vector that the per-iteration vector.extract ops read from. It is
    // returned by getDataVector, which is defined outside this view.
9220f241638SMatthias Springer     auto vec = getDataVector(xferOp);
9230f241638SMatthias Springer     auto xferVecType = xferOp.getVectorType();
    // Size of the leading vector dimension, i.e. the number of unrolled ops.
9240f241638SMatthias Springer     int64_t dimSize = xferVecType.getShape()[0];
9250f241638SMatthias Springer 
9260f241638SMatthias Springer     // Generate fully unrolled loop of transfer ops.
927*6825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
9280f241638SMatthias Springer     for (int64_t i = 0; i < dimSize; ++i) {
      // The iteration number is materialized as an index constant so that the
      // helpers below can treat it like a loop induction variable.
929*6825bfe2SNicolas Vasilache       Value iv = rewriter.create<ConstantIndexOp>(loc, i);
9300f241638SMatthias Springer 
      // Only an in-bounds callback is passed; no out-of-bounds callback is
      // supplied here. NOTE(review): presumably generateInBoundsCheck guards
      // the callback with a runtime check when needed -- confirm against its
      // definition.
9310f241638SMatthias Springer       generateInBoundsCheck(
932*6825bfe2SNicolas Vasilache           rewriter, xferOp, iv, unpackedDim(xferOp),
          // The callback's `loc` parameter shadows the outer `loc`.
9330f241638SMatthias Springer           /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
9340f241638SMatthias Springer             // Indices for the new transfer op.
9350f241638SMatthias Springer             SmallVector<Value, 8> xferIndices;
936*6825bfe2SNicolas Vasilache             getXferIndices(b, xferOp, iv, xferIndices);
9370f241638SMatthias Springer 
9380f241638SMatthias Springer             // Indices for the new vector.extract op.
            // These are the positions of a feeding ExtractOp (if any),
            // followed by the current iteration number `i`.
9390f241638SMatthias Springer             SmallVector<int64_t, 8> extractionIndices;
9400f241638SMatthias Springer             getExtractionIndices(xferOp, extractionIndices);
9410f241638SMatthias Springer             extractionIndices.push_back(i);
9420f241638SMatthias Springer 
943*6825bfe2SNicolas Vasilache             auto extracted =
944*6825bfe2SNicolas Vasilache                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
9450f241638SMatthias Springer             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
9460f241638SMatthias Springer 
            // The new write stores the extracted value to the same source.
            // NOTE(review): the empty Type() and Value() arguments are
            // presumably the result type and the mask operand -- confirm
            // against the TransferWriteOp builder.
947*6825bfe2SNicolas Vasilache             auto newXferOp = b.create<vector::TransferWriteOp>(
948*6825bfe2SNicolas Vasilache                 loc, Type(), extracted, xferOp.source(), xferIndices,
949*6825bfe2SNicolas Vasilache                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
950*6825bfe2SNicolas Vasilache                 inBoundsAttr);
9510f241638SMatthias Springer 
            // NOTE(review): presumably attaches the part of the original
            // mask that belongs to iteration `i` -- see maybeAssignMask.
9520f241638SMatthias Springer             maybeAssignMask(b, xferOp, newXferOp, i);
9530f241638SMatthias Springer           });
9540f241638SMatthias Springer     }
9550f241638SMatthias Springer 
    // All unrolled ops have been created; drop the original transfer op.
9560f241638SMatthias Springer     rewriter.eraseOp(xferOp);
9570f241638SMatthias Springer     return success();
9580f241638SMatthias Springer   }
9590f241638SMatthias Springer };
9600f241638SMatthias Springer 
961a088bed4SMatthias Springer } // namespace lowering_n_d_unrolled
962a088bed4SMatthias Springer 
963a088bed4SMatthias Springer namespace lowering_1_d {
964a088bed4SMatthias Springer 
9650f241638SMatthias Springer /// Compute the indices into the memref for the LoadOp/StoreOp generated as
9660f241638SMatthias Springer /// part of TransferOp1dConversion. Return the memref dimension on which
9670f241638SMatthias Springer /// the transfer is operating. A return value of None indicates a broadcast.
///
/// The transfer op's indices are appended to `memrefIndices`. If the single
/// permutation map result is a dimension expression, the entry for that
/// dimension is replaced by `index + iv`, built as a composed affine apply.
/// Otherwise the appended indices are left untouched. Callers are expected to
/// pass an empty `memrefIndices`, since the dimension indexes it from the
/// start.
9680f241638SMatthias Springer template <typename OpTy>
9690f241638SMatthias Springer static Optional<int64_t>
970*6825bfe2SNicolas Vasilache get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
9710f241638SMatthias Springer                    SmallVector<Value, 8> &memrefIndices) {
9720f241638SMatthias Springer   auto indices = xferOp.indices();
9730f241638SMatthias Springer   auto map = xferOp.permutation_map();
9740f241638SMatthias Springer 
9750f241638SMatthias Springer   memrefIndices.append(indices.begin(), indices.end());
9760f241638SMatthias Springer   assert(map.getNumResults() == 1 &&
9770f241638SMatthias Springer          "Expected 1 permutation map result for 1D transfer");
  // Non-broadcast case: the vector dimension maps to memref dimension `dim`.
9780f241638SMatthias Springer   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
979*6825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
9800f241638SMatthias Springer     auto dim = expr.getPosition();
    // Offset the original index along `dim` by the induction variable
    // (d0 + d1 applied to {offset, iv}).
981*6825bfe2SNicolas Vasilache     AffineExpr d0, d1;
982*6825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
983*6825bfe2SNicolas Vasilache     Value offset = memrefIndices[dim];
984*6825bfe2SNicolas Vasilache     memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
9850f241638SMatthias Springer     return dim;
9860f241638SMatthias Springer   }
9870f241638SMatthias Springer 
  // Anything that is not a dimension expression must be a broadcast.
9880f241638SMatthias Springer   assert(xferOp.isBroadcastDim(0) &&
9890f241638SMatthias Springer          "Expected AffineDimExpr or AffineConstantExpr");
9900f241638SMatthias Springer   return None;
9910f241638SMatthias Springer }
9920f241638SMatthias Springer 
9930f241638SMatthias Springer /// Codegen strategy for TransferOp1dConversion, depending on the
9940f241638SMatthias Springer /// operation.
///
/// The primary template is only declared. The explicit specializations for
/// TransferReadOp and TransferWriteOp provide the static functions
/// `generateForLoopBody` and `initialLoopState`.
9950f241638SMatthias Springer template <typename OpTy>
9960f241638SMatthias Springer struct Strategy1d;
9970f241638SMatthias Springer 
9980f241638SMatthias Springer /// Codegen strategy for TransferReadOp.
///
/// The loop carries the result vector as its single loop state value. Each
/// iteration loads one scalar and inserts it into that vector at position
/// `iv`, or forwards the vector unchanged in the out-of-bounds case.
9990f241638SMatthias Springer template <>
10000f241638SMatthias Springer struct Strategy1d<TransferReadOp> {
  /// Build the body of the scf.for loop for one vector element and terminate
  /// it with an scf.yield of the updated vector. `loopState[0]` is the vector
  /// carried in from the previous iteration.
1001*6825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
10020f241638SMatthias Springer                                   TransferReadOp xferOp, Value iv,
10030f241638SMatthias Springer                                   ValueRange loopState) {
10040f241638SMatthias Springer     SmallVector<Value, 8> indices;
1005*6825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    // The element position passed to vector.insertelement is `iv` cast to
    // i32.
1006*6825bfe2SNicolas Vasilache     Value ivI32 =
1007*6825bfe2SNicolas Vasilache         b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
10080f241638SMatthias Springer     auto vec = loopState[0];
10090f241638SMatthias Springer 
10100f241638SMatthias Springer     // In case of out-of-bounds access, leave `vec` as is (was initialized with
10110f241638SMatthias Springer     // padding value).
10120f241638SMatthias Springer     auto nextVec = generateInBoundsCheck(
1013*6825bfe2SNicolas Vasilache         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
10140f241638SMatthias Springer         /*inBoundsCase=*/
        // The callback's `b` and `loc` shadow the enclosing parameters.
1015*6825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
1016*6825bfe2SNicolas Vasilache           Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
1017*6825bfe2SNicolas Vasilache           return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
10180f241638SMatthias Springer         },
10190f241638SMatthias Springer         /*outOfBoundsCase=*/
10200f241638SMatthias Springer         [&](OpBuilder & /*b*/, Location loc) { return vec; });
1021*6825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc, nextVec);
10220f241638SMatthias Springer   }
10230f241638SMatthias Springer 
  /// Return the value the loop starts from: a splat of the op's padding value
  /// with the op's vector type.
1024*6825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
10250f241638SMatthias Springer     // Initialize vector with padding value.
1026*6825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
1027*6825bfe2SNicolas Vasilache     return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
10280f241638SMatthias Springer   }
10290f241638SMatthias Springer };
10300f241638SMatthias Springer 
10310f241638SMatthias Springer /// Codegen strategy for TransferWriteOp.
///
/// No value is carried through the loop. Each iteration extracts the element
/// at position `iv` from the written vector and stores it to the memref.
10320f241638SMatthias Springer template <>
10330f241638SMatthias Springer struct Strategy1d<TransferWriteOp> {
  /// Build the body of the scf.for loop for one vector element and terminate
  /// it with an scf.yield that has no operands. The loop state is unused.
1034*6825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
10350f241638SMatthias Springer                                   TransferWriteOp xferOp, Value iv,
10360f241638SMatthias Springer                                   ValueRange /*loopState*/) {
10370f241638SMatthias Springer     SmallVector<Value, 8> indices;
1038*6825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
    // The element position passed to vector.extractelement is `iv` cast to
    // i32.
1039*6825bfe2SNicolas Vasilache     Value ivI32 =
1040*6825bfe2SNicolas Vasilache         b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
10410f241638SMatthias Springer 
10420f241638SMatthias Springer     // Nothing to do in case of out-of-bounds access.
10430f241638SMatthias Springer     generateInBoundsCheck(
1044*6825bfe2SNicolas Vasilache         b, xferOp, iv, dim,
        // The callback's `b` and `loc` shadow the enclosing parameters.
1045*6825bfe2SNicolas Vasilache         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
1046*6825bfe2SNicolas Vasilache           auto val =
1047*6825bfe2SNicolas Vasilache               b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
1048*6825bfe2SNicolas Vasilache           b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
10490f241638SMatthias Springer         });
1050*6825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc);
10510f241638SMatthias Springer   }
10520f241638SMatthias Springer 
  /// Writes carry no loop state, so a null Value is returned.
1053*6825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
1054*6825bfe2SNicolas Vasilache     return Value();
1055*6825bfe2SNicolas Vasilache   }
10560f241638SMatthias Springer };
10570f241638SMatthias Springer 
10580f241638SMatthias Springer /// Return true if the last dimension of the MemRefType has unit stride.
///
/// Returns false when getStridesAndOffset fails. When it succeeds but reports
/// no strides (presumably a rank-0 memref -- TODO confirm), there is no last
/// dimension that could be non-unit, so true is returned instead of reading
/// `strides.back()` on an empty vector.
10590f241638SMatthias Springer static bool isLastMemrefDimUnitStride(MemRefType type) {
10600f241638SMatthias Springer   int64_t offset;
10610f241638SMatthias Springer   SmallVector<int64_t, 4> strides;
10620f241638SMatthias Springer   auto successStrides = getStridesAndOffset(type, strides, offset);
10630f241638SMatthias Springer   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
10640f241638SMatthias Springer }
10650f241638SMatthias Springer 
10660f241638SMatthias Springer /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
10670f241638SMatthias Springer /// necessary in cases where a 1D vector transfer op cannot be lowered into
10680f241638SMatthias Springer /// vector load/stores due to non-unit strides or broadcasts:
10690f241638SMatthias Springer ///
10700f241638SMatthias Springer /// * Transfer dimension is not the last memref dimension
10710f241638SMatthias Springer /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
10720f241638SMatthias Springer /// * Memref has a layout map with non-unit stride on the last dimension
10730f241638SMatthias Springer ///
10740f241638SMatthias Springer /// This pattern generates IR as follows:
10750f241638SMatthias Springer ///
10760f241638SMatthias Springer /// 1. Generate a for loop iterating over each vector element.
10770f241638SMatthias Springer /// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
10780f241638SMatthias Springer ///    depending on OpTy.
10790f241638SMatthias Springer ///
10800f241638SMatthias Springer /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
10810f241638SMatthias Springer ///       can be generated instead of TransferOp1dConversion. Add such a pattern
10820f241638SMatthias Springer ///       to ConvertVectorToLLVM.
10830f241638SMatthias Springer ///
10840f241638SMatthias Springer /// E.g.:
10850f241638SMatthias Springer /// ```
10860f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b]
10870f241638SMatthias Springer ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
10880f241638SMatthias Springer ///    : vector<9xf32>, memref<?x?xf32>
10890f241638SMatthias Springer /// ```
10900f241638SMatthias Springer /// Is rewritten to approximately the following pseudo-IR:
10910f241638SMatthias Springer /// ```
10920f241638SMatthias Springer /// for i = 0 to 9 {
10930f241638SMatthias Springer ///   %t = vector.extractelement %vec[i] : vector<9xf32>
10940f241638SMatthias Springer ///   memref.store %t, %A[%a + i, %b] : memref<?x?xf32>
10950f241638SMatthias Springer /// }
10960f241638SMatthias Springer /// ```
10970f241638SMatthias Springer template <typename OpTy>
10982ca887deSMatthias Springer struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
10992ca887deSMatthias Springer   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
11000f241638SMatthias Springer 
  /// Match 1-D transfers on memrefs that are not a minor-identity access on a
  /// unit-stride last dimension, and replace them with an scf.for loop whose
  /// body is generated by Strategy1d<OpTy>.
11010f241638SMatthias Springer   LogicalResult matchAndRewrite(OpTy xferOp,
11020f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
11030f241638SMatthias Springer     auto map = xferOp.permutation_map();
11040f241638SMatthias Springer     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
11050f241638SMatthias Springer 
    // Only transfers whose source is a memref are handled.
11060f241638SMatthias Springer     if (!memRefType)
11070f241638SMatthias Springer       return failure();
11080f241638SMatthias Springer     if (xferOp.getVectorType().getRank() != 1)
11090f241638SMatthias Springer       return failure();
11100f241638SMatthias Springer     if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
11110f241638SMatthias Springer       return failure(); // Handled by ConvertVectorToLLVM
11120f241638SMatthias Springer 
11130f241638SMatthias Springer     // Loop bounds, step, state...
    // The loop runs from 0 to the vector size in steps of 1. `loopState` is a
    // null Value for transfer writes (see Strategy1d<TransferWriteOp>).
1114*6825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
11150f241638SMatthias Springer     auto vecType = xferOp.getVectorType();
1116*6825bfe2SNicolas Vasilache     auto lb = rewriter.create<ConstantIndexOp>(loc, 0);
1117*6825bfe2SNicolas Vasilache     auto ub = rewriter.create<ConstantIndexOp>(loc, vecType.getDimSize(0));
1118*6825bfe2SNicolas Vasilache     auto step = rewriter.create<ConstantIndexOp>(loc, 1);
1119*6825bfe2SNicolas Vasilache     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
11200f241638SMatthias Springer 
11210f241638SMatthias Springer     // Generate for loop.
    // A null `loopState` yields a loop without iteration arguments. The
    // callback's `loc` and `loopState` parameters shadow the outer variables.
11220f241638SMatthias Springer     rewriter.replaceOpWithNewOp<scf::ForOp>(
11230f241638SMatthias Springer         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
1124*6825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
1125*6825bfe2SNicolas Vasilache           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
11260f241638SMatthias Springer         });
11270f241638SMatthias Springer 
11280f241638SMatthias Springer     return success();
11290f241638SMatthias Springer   }
11300f241638SMatthias Springer };
11314ead2cf7SAlex Zinenko 
1132a088bed4SMatthias Springer } // namespace lowering_1_d
1133df63eedeSBenjamin Kramer } // namespace
1134df63eedeSBenjamin Kramer 
113551d30c34SBenjamin Kramer namespace mlir {
113651d30c34SBenjamin Kramer 
/// Add the vector-transfer-to-SCF patterns selected by `options` to
/// `patterns`. When `options.unroll` is set, the lowering_n_d_unrolled
/// patterns are added; otherwise the lowering_n_d patterns are added. When
/// `options.targetRank` is 1, the lowering_1_d patterns are added as well.
/// Every pattern is constructed with the pattern set's context and `options`.
11373393cc4cSNicolas Vasilache void populateVectorToSCFConversionPatterns(
1138dc4e913bSChris Lattner     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
11390f241638SMatthias Springer   if (options.unroll) {
1140a088bed4SMatthias Springer     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1141a088bed4SMatthias Springer                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
11422ca887deSMatthias Springer         patterns.getContext(), options);
11430f241638SMatthias Springer   } else {
1144a088bed4SMatthias Springer     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1145a088bed4SMatthias Springer                  lowering_n_d::PrepareTransferWriteConversion,
1146a088bed4SMatthias Springer                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1147a088bed4SMatthias Springer                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1148a088bed4SMatthias Springer         patterns.getContext(), options);
11490f241638SMatthias Springer   }
11500f241638SMatthias Springer 
  // The scalar-loop lowering of 1-D transfers is only added when lowering
  // down to rank 1.
11512ca887deSMatthias Springer   if (options.targetRank == 1) {
1152a088bed4SMatthias Springer     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1153a088bed4SMatthias Springer                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1154a088bed4SMatthias Springer         patterns.getContext(), options);
11550f241638SMatthias Springer   }
11564ead2cf7SAlex Zinenko }
11573393cc4cSNicolas Vasilache 
11583393cc4cSNicolas Vasilache } // namespace mlir
11593393cc4cSNicolas Vasilache 
11605f9e0466SNicolas Vasilache namespace {
11615f9e0466SNicolas Vasilache 
/// Function pass that lowers vector transfer ops to SCF by greedily applying
/// the patterns from populateVectorToSCFConversionPatterns, optionally after
/// a separate round of permutation map lowering patterns.
11625f9e0466SNicolas Vasilache struct ConvertVectorToSCFPass
11635f9e0466SNicolas Vasilache     : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
11645f9e0466SNicolas Vasilache   ConvertVectorToSCFPass() = default;
  /// Copy the given options into the pass members `fullUnroll`, `targetRank`
  /// and `lowerPermutationMaps`. NOTE(review): these are presumably pass
  /// options declared by ConvertVectorToSCFBase -- confirm.
11655f9e0466SNicolas Vasilache   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
11665f9e0466SNicolas Vasilache     this->fullUnroll = options.unroll;
11672ca887deSMatthias Springer     this->targetRank = options.targetRank;
1168fb7ec1f1SMatthias Springer     this->lowerPermutationMaps = options.lowerPermutationMaps;
11695f9e0466SNicolas Vasilache   }
11705f9e0466SNicolas Vasilache 
  /// Rebuild the options from the pass members and run the rewrites on the
  /// current function.
11715f9e0466SNicolas Vasilache   void runOnFunction() override {
11722ca887deSMatthias Springer     VectorTransferToSCFOptions options;
1173fb7ec1f1SMatthias Springer     options.unroll = fullUnroll;
1174fb7ec1f1SMatthias Springer     options.targetRank = targetRank;
1175fb7ec1f1SMatthias Springer     options.lowerPermutationMaps = lowerPermutationMaps;
1176fb7ec1f1SMatthias Springer 
1177fb7ec1f1SMatthias Springer     // Lower permutation maps first.
    // This runs as its own greedy rewrite before the VectorToSCF patterns.
    // The result of the greedy driver is explicitly discarded.
1178fb7ec1f1SMatthias Springer     if (lowerPermutationMaps) {
1179fb7ec1f1SMatthias Springer       RewritePatternSet lowerTransferPatterns(getFunction().getContext());
1180fb7ec1f1SMatthias Springer       mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1181fb7ec1f1SMatthias Springer           lowerTransferPatterns);
1182fb7ec1f1SMatthias Springer       (void)applyPatternsAndFoldGreedily(getFunction(),
1183fb7ec1f1SMatthias Springer                                          std::move(lowerTransferPatterns));
1184fb7ec1f1SMatthias Springer     }
11852ca887deSMatthias Springer 
    // Apply the transfer-to-SCF patterns. The driver result is discarded
    // here as well.
1186dc4e913bSChris Lattner     RewritePatternSet patterns(getFunction().getContext());
11872ca887deSMatthias Springer     populateVectorToSCFConversionPatterns(patterns, options);
1188e21adfa3SRiver Riddle     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
11895f9e0466SNicolas Vasilache   }
11905f9e0466SNicolas Vasilache };
11915f9e0466SNicolas Vasilache 
11925f9e0466SNicolas Vasilache } // namespace
11935f9e0466SNicolas Vasilache 
/// Create a ConvertVectorToSCFPass configured with the given options.
11945f9e0466SNicolas Vasilache std::unique_ptr<Pass>
11955f9e0466SNicolas Vasilache mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
11965f9e0466SNicolas Vasilache   return std::make_unique<ConvertVectorToSCFPass>(options);
11975f9e0466SNicolas Vasilache }
1198