10f241638SMatthias Springer //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
24ead2cf7SAlex Zinenko //
34ead2cf7SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44ead2cf7SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
54ead2cf7SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64ead2cf7SAlex Zinenko //
74ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===//
84ead2cf7SAlex Zinenko //
90f241638SMatthias Springer // This file implements lowering of vector transfer operations to SCF.
104ead2cf7SAlex Zinenko //
114ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===//
124ead2cf7SAlex Zinenko 
134ead2cf7SAlex Zinenko #include <type_traits>
144ead2cf7SAlex Zinenko 
154ead2cf7SAlex Zinenko #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
165f9e0466SNicolas Vasilache 
175f9e0466SNicolas Vasilache #include "../PassDetail.h"
186825bfe2SNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h"
196825bfe2SNicolas Vasilache #include "mlir/Dialect/Affine/Utils.h"
206825bfe2SNicolas Vasilache #include "mlir/Dialect/SCF/SCF.h"
214ead2cf7SAlex Zinenko #include "mlir/Dialect/Vector/VectorOps.h"
227c3c5b11SNicolas Vasilache #include "mlir/Dialect/Vector/VectorUtils.h"
234ead2cf7SAlex Zinenko #include "mlir/IR/Builders.h"
246825bfe2SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
255f9e0466SNicolas Vasilache #include "mlir/Pass/Pass.h"
26b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
275f9e0466SNicolas Vasilache #include "mlir/Transforms/Passes.h"
284ead2cf7SAlex Zinenko 
294ead2cf7SAlex Zinenko using namespace mlir;
304ead2cf7SAlex Zinenko using vector::TransferReadOp;
314ead2cf7SAlex Zinenko using vector::TransferWriteOp;
324ead2cf7SAlex Zinenko 
33350dadaaSBenjamin Kramer namespace {
340f241638SMatthias Springer 
/// Name of the unit attribute that tags vector transfer ops which are still
/// pending further progressive lowering.
static constexpr char kPassLabel[] = "__vector_to_scf_lowering__";
370f241638SMatthias Springer 
382ca887deSMatthias Springer /// Patterns that inherit from this struct have access to
392ca887deSMatthias Springer /// VectorTransferToSCFOptions.
402ca887deSMatthias Springer template <typename OpTy>
412ca887deSMatthias Springer struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
422ca887deSMatthias Springer   explicit VectorToSCFPattern(MLIRContext *context,
432ca887deSMatthias Springer                               VectorTransferToSCFOptions opt)
442ca887deSMatthias Springer       : OpRewritePattern<OpTy>(context), options(opt) {}
452ca887deSMatthias Springer 
462ca887deSMatthias Springer   VectorTransferToSCFOptions options;
472ca887deSMatthias Springer };
480f241638SMatthias Springer 
490f241638SMatthias Springer /// Given a vector transfer op, calculate which dimension of the `source`
500f241638SMatthias Springer /// memref should be unpacked in the next application of TransferOpConversion.
510f241638SMatthias Springer /// A return value of None indicates a broadcast.
520f241638SMatthias Springer template <typename OpTy>
530f241638SMatthias Springer static Optional<int64_t> unpackedDim(OpTy xferOp) {
540f241638SMatthias Springer   auto map = xferOp.permutation_map();
550f241638SMatthias Springer   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
560f241638SMatthias Springer     return expr.getPosition();
577c3c5b11SNicolas Vasilache   }
580f241638SMatthias Springer   assert(xferOp.isBroadcastDim(0) &&
590f241638SMatthias Springer          "Expected AffineDimExpr or AffineConstantExpr");
600f241638SMatthias Springer   return None;
610f241638SMatthias Springer }
620f241638SMatthias Springer 
630f241638SMatthias Springer /// Compute the permutation map for the new (N-1)-D vector transfer op. This
640f241638SMatthias Springer /// map is identical to the current permutation map, but the first result is
650f241638SMatthias Springer /// omitted.
660f241638SMatthias Springer template <typename OpTy>
676825bfe2SNicolas Vasilache static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
680f241638SMatthias Springer   auto map = xferOp.permutation_map();
690f241638SMatthias Springer   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
706825bfe2SNicolas Vasilache                         b.getContext());
710f241638SMatthias Springer }
720f241638SMatthias Springer 
730f241638SMatthias Springer /// Calculate the indices for the new vector transfer op.
740f241638SMatthias Springer ///
750f241638SMatthias Springer /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
760f241638SMatthias Springer ///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32>
770f241638SMatthias Springer ///                                 ^^^^^^
780f241638SMatthias Springer ///              `iv` is the iteration variable of the (new) surrounding loop.
790f241638SMatthias Springer template <typename OpTy>
806825bfe2SNicolas Vasilache static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
810f241638SMatthias Springer                            SmallVector<Value, 8> &indices) {
820f241638SMatthias Springer   typename OpTy::Adaptor adaptor(xferOp);
830f241638SMatthias Springer   // Corresponding memref dim of the vector dim that is unpacked.
840f241638SMatthias Springer   auto dim = unpackedDim(xferOp);
850f241638SMatthias Springer   auto prevIndices = adaptor.indices();
860f241638SMatthias Springer   indices.append(prevIndices.begin(), prevIndices.end());
870f241638SMatthias Springer 
886825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
890f241638SMatthias Springer   bool isBroadcast = !dim.hasValue();
900f241638SMatthias Springer   if (!isBroadcast) {
916825bfe2SNicolas Vasilache     AffineExpr d0, d1;
926825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
936825bfe2SNicolas Vasilache     Value offset = adaptor.indices()[dim.getValue()];
946825bfe2SNicolas Vasilache     indices[dim.getValue()] =
956825bfe2SNicolas Vasilache         makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
960f241638SMatthias Springer   }
970f241638SMatthias Springer }
980f241638SMatthias Springer 
996825bfe2SNicolas Vasilache static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
1000f241638SMatthias Springer                             Value value) {
1010f241638SMatthias Springer   if (hasRetVal) {
1026825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc, value);
1030f241638SMatthias Springer   } else {
1046825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc);
1050f241638SMatthias Springer   }
1060f241638SMatthias Springer }
1070f241638SMatthias Springer 
1080f241638SMatthias Springer /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
1090f241638SMatthias Springer /// is set to true. No such check is generated under following circumstances:
1100f241638SMatthias Springer /// * xferOp does not have a mask.
1110f241638SMatthias Springer /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
1120f241638SMatthias Springer ///   computed and attached to the new transfer op in the pattern.)
1130f241638SMatthias Springer /// * The to-be-unpacked dim of xferOp is a broadcast.
1140f241638SMatthias Springer template <typename OpTy>
1156825bfe2SNicolas Vasilache static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
1160f241638SMatthias Springer   if (!xferOp.mask())
1170f241638SMatthias Springer     return Value();
1180f241638SMatthias Springer   if (xferOp.getMaskType().getRank() != 1)
1190f241638SMatthias Springer     return Value();
1200f241638SMatthias Springer   if (xferOp.isBroadcastDim(0))
1210f241638SMatthias Springer     return Value();
1220f241638SMatthias Springer 
1236825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
1246825bfe2SNicolas Vasilache   Value ivI32 =
1256825bfe2SNicolas Vasilache       b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
1266825bfe2SNicolas Vasilache   return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
1270f241638SMatthias Springer }
1280f241638SMatthias Springer 
1290f241638SMatthias Springer /// Helper function TransferOpConversion and TransferOp1dConversion.
1300f241638SMatthias Springer /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
1310f241638SMatthias Springer /// specified dimension `dim` with the loop iteration variable `iv`.
1320f241638SMatthias Springer /// E.g., when unpacking dimension 0 from:
1330f241638SMatthias Springer /// ```
1340f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b] %cst
1350f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?xf32>
1360f241638SMatthias Springer /// ```
1370f241638SMatthias Springer /// An if check similar to this will be generated inside the loop:
1380f241638SMatthias Springer /// ```
1390f241638SMatthias Springer /// %d = memref.dim %A, %c0 : memref<?x?xf32>
1400f241638SMatthias Springer /// if (%a + iv < %d) {
1410f241638SMatthias Springer ///   (in-bounds case)
1420f241638SMatthias Springer /// } else {
1430f241638SMatthias Springer ///   (out-of-bounds case)
1440f241638SMatthias Springer /// }
1450f241638SMatthias Springer /// ```
1460f241638SMatthias Springer ///
1470f241638SMatthias Springer /// If the transfer is 1D and has a mask, this function generates a more complex
1480f241638SMatthias Springer /// check also accounts for potentially masked out elements.
1490f241638SMatthias Springer ///
1500f241638SMatthias Springer /// This function variant returns the value returned by `inBoundsCase` or
1510f241638SMatthias Springer /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
1520f241638SMatthias Springer /// `resultTypes`.
1530f241638SMatthias Springer template <typename OpTy>
1540f241638SMatthias Springer static Value generateInBoundsCheck(
1556825bfe2SNicolas Vasilache     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
1560f241638SMatthias Springer     TypeRange resultTypes,
1570f241638SMatthias Springer     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
1580f241638SMatthias Springer     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
1590f241638SMatthias Springer   bool hasRetVal = !resultTypes.empty();
1600f241638SMatthias Springer   Value cond; // Condition to be built...
1610f241638SMatthias Springer 
1620f241638SMatthias Springer   // Condition check 1: Access in-bounds?
1630f241638SMatthias Springer   bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
1646825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
1656825bfe2SNicolas Vasilache   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
1660f241638SMatthias Springer   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
1676825bfe2SNicolas Vasilache     Value memrefDim = lb.create<memref::DimOp>(xferOp.source(), *dim);
1686825bfe2SNicolas Vasilache     AffineExpr d0, d1;
1696825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
1706825bfe2SNicolas Vasilache     Value base = xferOp.indices()[dim.getValue()];
1716825bfe2SNicolas Vasilache     Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
1726825bfe2SNicolas Vasilache     cond = lb.create<CmpIOp>(CmpIPredicate::sgt, memrefDim, memrefIdx);
1730f241638SMatthias Springer   }
1740f241638SMatthias Springer 
1750f241638SMatthias Springer   // Condition check 2: Masked in?
1766825bfe2SNicolas Vasilache   if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
1776825bfe2SNicolas Vasilache     if (cond)
1786825bfe2SNicolas Vasilache       cond = lb.create<AndOp>(cond, maskCond);
1796825bfe2SNicolas Vasilache     else
1800f241638SMatthias Springer       cond = maskCond;
1810f241638SMatthias Springer   }
1820f241638SMatthias Springer 
1830f241638SMatthias Springer   // If the condition is non-empty, generate an SCF::IfOp.
1840f241638SMatthias Springer   if (cond) {
1856825bfe2SNicolas Vasilache     auto check = lb.create<scf::IfOp>(
1866825bfe2SNicolas Vasilache         resultTypes, cond,
1870f241638SMatthias Springer         /*thenBuilder=*/
1886825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
1896825bfe2SNicolas Vasilache           maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
190cadb7ccfSAlex Zinenko         },
1910f241638SMatthias Springer         /*elseBuilder=*/
1926825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
1930f241638SMatthias Springer           if (outOfBoundsCase) {
1946825bfe2SNicolas Vasilache             maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
1957c3c5b11SNicolas Vasilache           } else {
1966825bfe2SNicolas Vasilache             b.create<scf::YieldOp>(loc);
1977c3c5b11SNicolas Vasilache           }
1987c3c5b11SNicolas Vasilache         });
1997c3c5b11SNicolas Vasilache 
2000f241638SMatthias Springer     return hasRetVal ? check.getResult(0) : Value();
2014ead2cf7SAlex Zinenko   }
2024ead2cf7SAlex Zinenko 
2030f241638SMatthias Springer   // Condition is empty, no need for an SCF::IfOp.
2046825bfe2SNicolas Vasilache   return inBoundsCase(b, loc);
2050f241638SMatthias Springer }
2060f241638SMatthias Springer 
2070f241638SMatthias Springer /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
2080f241638SMatthias Springer /// a return value. Consequently, this function does not have a return value.
2090f241638SMatthias Springer template <typename OpTy>
2100f241638SMatthias Springer static void generateInBoundsCheck(
2116825bfe2SNicolas Vasilache     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
2120f241638SMatthias Springer     function_ref<void(OpBuilder &, Location)> inBoundsCase,
2130f241638SMatthias Springer     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
2140f241638SMatthias Springer   generateInBoundsCheck(
2156825bfe2SNicolas Vasilache       b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
2160f241638SMatthias Springer       /*inBoundsCase=*/
2176825bfe2SNicolas Vasilache       [&](OpBuilder &b, Location loc) {
2186825bfe2SNicolas Vasilache         inBoundsCase(b, loc);
2190f241638SMatthias Springer         return Value();
2200f241638SMatthias Springer       },
2210f241638SMatthias Springer       /*outOfBoundsCase=*/
2226825bfe2SNicolas Vasilache       [&](OpBuilder &b, Location loc) {
2230f241638SMatthias Springer         if (outOfBoundsCase)
2246825bfe2SNicolas Vasilache           outOfBoundsCase(b, loc);
2250f241638SMatthias Springer         return Value();
2260f241638SMatthias Springer       });
2270f241638SMatthias Springer }
2280f241638SMatthias Springer 
2290f241638SMatthias Springer /// Given an ArrayAttr, return a copy where the first element is dropped.
2306825bfe2SNicolas Vasilache static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
2310f241638SMatthias Springer   if (!attr)
2320f241638SMatthias Springer     return attr;
2336825bfe2SNicolas Vasilache   return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
2340f241638SMatthias Springer }
2350f241638SMatthias Springer 
2360f241638SMatthias Springer /// Add the pass label to a vector transfer op if its rank is not the target
2370f241638SMatthias Springer /// rank.
2380f241638SMatthias Springer template <typename OpTy>
2396825bfe2SNicolas Vasilache static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
2402ca887deSMatthias Springer                                 unsigned targetRank) {
2412ca887deSMatthias Springer   if (newXferOp.getVectorType().getRank() > targetRank)
2426825bfe2SNicolas Vasilache     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
2430f241638SMatthias Springer }
2440f241638SMatthias Springer 
245a088bed4SMatthias Springer namespace lowering_n_d {
246a088bed4SMatthias Springer 
247a088bed4SMatthias Springer /// Helper data structure for data and mask buffers.
248a088bed4SMatthias Springer struct BufferAllocs {
249a088bed4SMatthias Springer   Value dataBuffer;
250a088bed4SMatthias Springer   Value maskBuffer;
251a088bed4SMatthias Springer };
252a088bed4SMatthias Springer 
253a088bed4SMatthias Springer /// Allocate temporary buffers for data (vector) and mask (if present).
254a088bed4SMatthias Springer /// TODO: Parallelism and threadlocal considerations.
255a088bed4SMatthias Springer template <typename OpTy>
2566825bfe2SNicolas Vasilache static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
2576825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
258a088bed4SMatthias Springer   OpBuilder::InsertionGuard guard(b);
259a088bed4SMatthias Springer   Operation *scope =
260a088bed4SMatthias Springer       xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
261a088bed4SMatthias Springer   assert(scope && "Expected op to be inside automatic allocation scope");
262a088bed4SMatthias Springer   b.setInsertionPointToStart(&scope->getRegion(0).front());
263a088bed4SMatthias Springer 
264a088bed4SMatthias Springer   BufferAllocs result;
265a088bed4SMatthias Springer   auto bufferType = MemRefType::get({}, xferOp.getVectorType());
2666825bfe2SNicolas Vasilache   result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
267a088bed4SMatthias Springer 
268a088bed4SMatthias Springer   if (xferOp.mask()) {
269a088bed4SMatthias Springer     auto maskType = MemRefType::get({}, xferOp.mask().getType());
2706825bfe2SNicolas Vasilache     auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
271fb7ec1f1SMatthias Springer     b.setInsertionPoint(xferOp);
2726825bfe2SNicolas Vasilache     b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
2736825bfe2SNicolas Vasilache     result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
274a088bed4SMatthias Springer   }
275a088bed4SMatthias Springer 
276a088bed4SMatthias Springer   return result;
277a088bed4SMatthias Springer }
278a088bed4SMatthias Springer 
279a088bed4SMatthias Springer /// Given a MemRefType with VectorType element type, unpack one dimension from
280a088bed4SMatthias Springer /// the VectorType into the MemRefType.
281a088bed4SMatthias Springer ///
282a088bed4SMatthias Springer /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
283a088bed4SMatthias Springer static MemRefType unpackOneDim(MemRefType type) {
284a088bed4SMatthias Springer   auto vectorType = type.getElementType().dyn_cast<VectorType>();
285a088bed4SMatthias Springer   auto memrefShape = type.getShape();
286a088bed4SMatthias Springer   SmallVector<int64_t, 8> newMemrefShape;
287a088bed4SMatthias Springer   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
288a088bed4SMatthias Springer   newMemrefShape.push_back(vectorType.getDimSize(0));
289a088bed4SMatthias Springer   return MemRefType::get(newMemrefShape,
290a088bed4SMatthias Springer                          VectorType::get(vectorType.getShape().drop_front(),
291a088bed4SMatthias Springer                                          vectorType.getElementType()));
292a088bed4SMatthias Springer }
293a088bed4SMatthias Springer 
2940f241638SMatthias Springer /// Given a transfer op, find the memref from which the mask is loaded. This
2950f241638SMatthias Springer /// is similar to Strategy<TransferWriteOp>::getBuffer.
2960f241638SMatthias Springer template <typename OpTy>
2970f241638SMatthias Springer static Value getMaskBuffer(OpTy xferOp) {
2980f241638SMatthias Springer   assert(xferOp.mask() && "Expected that transfer op has mask");
2990f241638SMatthias Springer   auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
3000f241638SMatthias Springer   assert(loadOp && "Expected transfer op mask produced by LoadOp");
3010f241638SMatthias Springer   return loadOp.getMemRef();
3020f241638SMatthias Springer }
3030f241638SMatthias Springer 
/// Per-op codegen hooks for the progressive lowering patterns. Declared here
/// and specialized for TransferReadOp and TransferWriteOp.
template <typename OpTy>
struct Strategy;
3070f241638SMatthias Springer 
3080f241638SMatthias Springer /// Code strategy for vector TransferReadOp.
3094ead2cf7SAlex Zinenko template <>
3100f241638SMatthias Springer struct Strategy<TransferReadOp> {
3110f241638SMatthias Springer   /// Find the StoreOp that is used for writing the current TransferReadOp's
3120f241638SMatthias Springer   /// result to the temporary buffer allocation.
3130f241638SMatthias Springer   static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
3140f241638SMatthias Springer     assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
3150f241638SMatthias Springer     auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
3160f241638SMatthias Springer     assert(storeOp && "Expected TransferReadOp result used by StoreOp");
3170f241638SMatthias Springer     return storeOp;
3187c3c5b11SNicolas Vasilache   }
3194ead2cf7SAlex Zinenko 
3200f241638SMatthias Springer   /// Find the temporary buffer allocation. All labeled TransferReadOps are
3210f241638SMatthias Springer   /// used like this, where %buf is either the buffer allocation or a type cast
3220f241638SMatthias Springer   /// of the buffer allocation:
3230f241638SMatthias Springer   /// ```
3240f241638SMatthias Springer   /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
3250f241638SMatthias Springer   /// memref.store %vec, %buf[...] ...
3260f241638SMatthias Springer   /// ```
3270f241638SMatthias Springer   static Value getBuffer(TransferReadOp xferOp) {
3280f241638SMatthias Springer     return getStoreOp(xferOp).getMemRef();
3291870e787SNicolas Vasilache   }
3300f241638SMatthias Springer 
3310f241638SMatthias Springer   /// Retrieve the indices of the current StoreOp that stores into the buffer.
3320f241638SMatthias Springer   static void getBufferIndices(TransferReadOp xferOp,
3330f241638SMatthias Springer                                SmallVector<Value, 8> &indices) {
3340f241638SMatthias Springer     auto storeOp = getStoreOp(xferOp);
3350f241638SMatthias Springer     auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
3360f241638SMatthias Springer     indices.append(prevIndices.begin(), prevIndices.end());
3370f241638SMatthias Springer   }
3380f241638SMatthias Springer 
3390f241638SMatthias Springer   /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
3400f241638SMatthias Springer   /// accesses on the to-be-unpacked dimension.
3410f241638SMatthias Springer   ///
3420f241638SMatthias Springer   /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
3430f241638SMatthias Springer   ///    variable `iv`.
3440f241638SMatthias Springer   /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
3450f241638SMatthias Springer   ///
3460f241638SMatthias Springer   /// E.g.:
3470f241638SMatthias Springer   /// ```
3480f241638SMatthias Springer   /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
3490f241638SMatthias Springer   ///     : memref<?x?x?xf32>, vector<4x3xf32>
3500f241638SMatthias Springer   /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
3510f241638SMatthias Springer   /// ```
3520f241638SMatthias Springer   /// Is rewritten to:
3530f241638SMatthias Springer   /// ```
3540f241638SMatthias Springer   /// %casted = vector.type_cast %buf
3550f241638SMatthias Springer   ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
3560f241638SMatthias Springer   /// for %j = 0 to 4 {
3570f241638SMatthias Springer   ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
3580f241638SMatthias Springer   ///       : memref<?x?x?xf32>, vector<3xf32>
3590f241638SMatthias Springer   ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
3600f241638SMatthias Springer   /// }
3610f241638SMatthias Springer   /// ```
3620f241638SMatthias Springer   ///
3630f241638SMatthias Springer   /// Note: The loop and type cast are generated in TransferOpConversion.
3640f241638SMatthias Springer   ///       The original TransferReadOp and store op are deleted in `cleanup`.
3650f241638SMatthias Springer   /// Note: The `mask` operand is set in TransferOpConversion.
3666825bfe2SNicolas Vasilache   static TransferReadOp rewriteOp(OpBuilder &b,
3672ca887deSMatthias Springer                                   VectorTransferToSCFOptions options,
3682ca887deSMatthias Springer                                   TransferReadOp xferOp, Value buffer,
3692ca887deSMatthias Springer                                   Value iv) {
3700f241638SMatthias Springer     SmallVector<Value, 8> storeIndices;
3710f241638SMatthias Springer     getBufferIndices(xferOp, storeIndices);
3720f241638SMatthias Springer     storeIndices.push_back(iv);
3730f241638SMatthias Springer 
3740f241638SMatthias Springer     SmallVector<Value, 8> xferIndices;
3756825bfe2SNicolas Vasilache     getXferIndices(b, xferOp, iv, xferIndices);
3760f241638SMatthias Springer 
3776825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
3780f241638SMatthias Springer     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
3790f241638SMatthias Springer     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
3806825bfe2SNicolas Vasilache     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
3816825bfe2SNicolas Vasilache     auto newXferOp = b.create<vector::TransferReadOp>(
3826825bfe2SNicolas Vasilache         loc, vecType, xferOp.source(), xferIndices,
3836825bfe2SNicolas Vasilache         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
3846825bfe2SNicolas Vasilache         Value(), inBoundsAttr);
3850f241638SMatthias Springer 
3866825bfe2SNicolas Vasilache     maybeApplyPassLabel(b, newXferOp, options.targetRank);
3870f241638SMatthias Springer 
3886825bfe2SNicolas Vasilache     b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
3896825bfe2SNicolas Vasilache     return newXferOp;
3900f241638SMatthias Springer   }
3910f241638SMatthias Springer 
3920f241638SMatthias Springer   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
3930f241638SMatthias Springer   /// padding value to the temporary buffer.
3946825bfe2SNicolas Vasilache   static void handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
3956825bfe2SNicolas Vasilache                                    Value buffer, Value iv) {
3960f241638SMatthias Springer     SmallVector<Value, 8> storeIndices;
3970f241638SMatthias Springer     getBufferIndices(xferOp, storeIndices);
3980f241638SMatthias Springer     storeIndices.push_back(iv);
3990f241638SMatthias Springer 
4006825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
4010f241638SMatthias Springer     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
4020f241638SMatthias Springer     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
4036825bfe2SNicolas Vasilache     auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
4046825bfe2SNicolas Vasilache     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
4050f241638SMatthias Springer   }
4060f241638SMatthias Springer 
4070f241638SMatthias Springer   /// Cleanup after rewriting the op.
4080f241638SMatthias Springer   static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp) {
4090f241638SMatthias Springer     rewriter.eraseOp(getStoreOp(xferOp));
4100f241638SMatthias Springer     rewriter.eraseOp(xferOp);
4110f241638SMatthias Springer   }
4124ead2cf7SAlex Zinenko };
4137c3c5b11SNicolas Vasilache 
4140f241638SMatthias Springer /// Codegen strategy for vector TransferWriteOp.
4150f241638SMatthias Springer template <>
4160f241638SMatthias Springer struct Strategy<TransferWriteOp> {
4170f241638SMatthias Springer   /// Find the temporary buffer allocation. All labeled TransferWriteOps are
4180f241638SMatthias Springer   /// used like this, where %buf is either the buffer allocation or a type cast
4190f241638SMatthias Springer   /// of the buffer allocation:
4200f241638SMatthias Springer   /// ```
4210f241638SMatthias Springer   /// %vec = memref.load %buf[...] ...
4220f241638SMatthias Springer   /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
4230f241638SMatthias Springer   /// ```
4240f241638SMatthias Springer   static Value getBuffer(TransferWriteOp xferOp) {
4250f241638SMatthias Springer     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
4260f241638SMatthias Springer     assert(loadOp && "Expected transfer op vector produced by LoadOp");
4270f241638SMatthias Springer     return loadOp.getMemRef();
4287c3c5b11SNicolas Vasilache   }
4294ead2cf7SAlex Zinenko 
4300f241638SMatthias Springer   /// Retrieve the indices of the current LoadOp that loads from the buffer.
4310f241638SMatthias Springer   static void getBufferIndices(TransferWriteOp xferOp,
4320f241638SMatthias Springer                                SmallVector<Value, 8> &indices) {
4330f241638SMatthias Springer     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
4340f241638SMatthias Springer     auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
4350f241638SMatthias Springer     indices.append(prevIndices.begin(), prevIndices.end());
4360f241638SMatthias Springer   }
4370f241638SMatthias Springer 
4380f241638SMatthias Springer   /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
4390f241638SMatthias Springer   /// accesses on the to-be-unpacked dimension.
4400f241638SMatthias Springer   ///
4410f241638SMatthias Springer   /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
4420f241638SMatthias Springer   ///    using the loop iteration variable `iv`.
4430f241638SMatthias Springer   /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
4440f241638SMatthias Springer   ///    to memory.
4450f241638SMatthias Springer   ///
4460f241638SMatthias Springer   /// Note: For more details, see comments on Strategy<TransferReadOp>.
4476825bfe2SNicolas Vasilache   static TransferWriteOp rewriteOp(OpBuilder &b,
4482ca887deSMatthias Springer                                    VectorTransferToSCFOptions options,
4492ca887deSMatthias Springer                                    TransferWriteOp xferOp, Value buffer,
4502ca887deSMatthias Springer                                    Value iv) {
4510f241638SMatthias Springer     SmallVector<Value, 8> loadIndices;
4520f241638SMatthias Springer     getBufferIndices(xferOp, loadIndices);
4530f241638SMatthias Springer     loadIndices.push_back(iv);
4540f241638SMatthias Springer 
4550f241638SMatthias Springer     SmallVector<Value, 8> xferIndices;
4566825bfe2SNicolas Vasilache     getXferIndices(b, xferOp, iv, xferIndices);
4570f241638SMatthias Springer 
4586825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
4596825bfe2SNicolas Vasilache     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
4606825bfe2SNicolas Vasilache     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
4616825bfe2SNicolas Vasilache     auto newXferOp = b.create<vector::TransferWriteOp>(
4626825bfe2SNicolas Vasilache         loc, Type(), vec, xferOp.source(), xferIndices,
4636825bfe2SNicolas Vasilache         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
4640f241638SMatthias Springer         inBoundsAttr);
4650f241638SMatthias Springer 
4666825bfe2SNicolas Vasilache     maybeApplyPassLabel(b, newXferOp, options.targetRank);
4670f241638SMatthias Springer 
4686825bfe2SNicolas Vasilache     return newXferOp;
4690f241638SMatthias Springer   }
4700f241638SMatthias Springer 
4710f241638SMatthias Springer   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
4726825bfe2SNicolas Vasilache   static void handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
4730f241638SMatthias Springer                                    Value buffer, Value iv) {}
4740f241638SMatthias Springer 
4750f241638SMatthias Springer   /// Cleanup after rewriting the op.
4760f241638SMatthias Springer   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
4770f241638SMatthias Springer     rewriter.eraseOp(xferOp);
4780f241638SMatthias Springer   }
4790f241638SMatthias Springer };
4800f241638SMatthias Springer 
4810f241638SMatthias Springer template <typename OpTy>
482fb7ec1f1SMatthias Springer LogicalResult checkPrepareXferOp(OpTy xferOp,
483fb7ec1f1SMatthias Springer                                  VectorTransferToSCFOptions options) {
4840f241638SMatthias Springer   if (xferOp->hasAttr(kPassLabel))
4850f241638SMatthias Springer     return failure();
486fb7ec1f1SMatthias Springer   if (xferOp.getVectorType().getRank() <= options.targetRank)
4870f241638SMatthias Springer     return failure();
488*8fb48979SMatthias Springer   if (xferOp.getShapedType().template isa<RankedTensorType>())
489*8fb48979SMatthias Springer     return failure();
4900f241638SMatthias Springer   return success();
4910f241638SMatthias Springer }
4920f241638SMatthias Springer 
4930f241638SMatthias Springer /// Prepare a TransferReadOp for progressive lowering.
4940f241638SMatthias Springer ///
4950f241638SMatthias Springer /// 1. Allocate a temporary buffer.
4960f241638SMatthias Springer /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
4970f241638SMatthias Springer /// 3. Store the result of the TransferReadOp into the temporary buffer.
4980f241638SMatthias Springer /// 4. Load the result from the temporary buffer and replace all uses of the
4990f241638SMatthias Springer ///    original TransferReadOp with this load.
5000f241638SMatthias Springer ///
5010f241638SMatthias Springer /// E.g.:
5020f241638SMatthias Springer /// ```
5030f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
5040f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
5050f241638SMatthias Springer /// ```
5060f241638SMatthias Springer /// is rewritten to:
5070f241638SMatthias Springer /// ```
5080f241638SMatthias Springer /// %0 = memref.alloca() : memref<vector<5x4xf32>>
5090f241638SMatthias Springer /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
5100f241638SMatthias Springer ///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
5110f241638SMatthias Springer /// memref.store %1, %0[] : memref<vector<5x4xf32>>
5120f241638SMatthias Springer /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
5130f241638SMatthias Springer /// ```
5140f241638SMatthias Springer ///
5150f241638SMatthias Springer /// Note: A second temporary buffer may be allocated for the `mask` operand.
5162ca887deSMatthias Springer struct PrepareTransferReadConversion
5172ca887deSMatthias Springer     : public VectorToSCFPattern<TransferReadOp> {
5182ca887deSMatthias Springer   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
5190f241638SMatthias Springer 
5200f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferReadOp xferOp,
5210f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
522fb7ec1f1SMatthias Springer     if (checkPrepareXferOp(xferOp, options).failed())
5230f241638SMatthias Springer       return failure();
5240f241638SMatthias Springer 
5256825bfe2SNicolas Vasilache     auto buffers = allocBuffers(rewriter, xferOp);
5260f241638SMatthias Springer     auto *newXfer = rewriter.clone(*xferOp.getOperation());
5270f241638SMatthias Springer     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
5280f241638SMatthias Springer     if (xferOp.mask()) {
5290f241638SMatthias Springer       dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
5300f241638SMatthias Springer           buffers.maskBuffer);
5310f241638SMatthias Springer     }
5320f241638SMatthias Springer 
5336825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
5346825bfe2SNicolas Vasilache     rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
5356825bfe2SNicolas Vasilache                                      buffers.dataBuffer);
5360f241638SMatthias Springer     rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
5374ead2cf7SAlex Zinenko 
5384ead2cf7SAlex Zinenko     return success();
5394ead2cf7SAlex Zinenko   }
5400f241638SMatthias Springer };
5410f241638SMatthias Springer 
5420f241638SMatthias Springer /// Prepare a TransferWriteOp for progressive lowering.
5430f241638SMatthias Springer ///
5440f241638SMatthias Springer /// 1. Allocate a temporary buffer.
5450f241638SMatthias Springer /// 2. Store the vector into the buffer.
5460f241638SMatthias Springer /// 3. Load the vector from the buffer again.
5470f241638SMatthias Springer /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
5480f241638SMatthias Springer ///    marking it eligible for progressive lowering via TransferOpConversion.
5490f241638SMatthias Springer ///
5500f241638SMatthias Springer /// E.g.:
5510f241638SMatthias Springer /// ```
5520f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c]
5530f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
5540f241638SMatthias Springer /// ```
5550f241638SMatthias Springer /// is rewritten to:
5560f241638SMatthias Springer /// ```
5570f241638SMatthias Springer /// %0 = memref.alloca() : memref<vector<5x4xf32>>
5580f241638SMatthias Springer /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
5590f241638SMatthias Springer /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
5600f241638SMatthias Springer /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
5610f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
5620f241638SMatthias Springer /// ```
5630f241638SMatthias Springer ///
5640f241638SMatthias Springer /// Note: A second temporary buffer may be allocated for the `mask` operand.
5650f241638SMatthias Springer struct PrepareTransferWriteConversion
5662ca887deSMatthias Springer     : public VectorToSCFPattern<TransferWriteOp> {
5672ca887deSMatthias Springer   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
5680f241638SMatthias Springer 
5690f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
5700f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
571fb7ec1f1SMatthias Springer     if (checkPrepareXferOp(xferOp, options).failed())
5720f241638SMatthias Springer       return failure();
5730f241638SMatthias Springer 
5746825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
5756825bfe2SNicolas Vasilache     auto buffers = allocBuffers(rewriter, xferOp);
5766825bfe2SNicolas Vasilache     rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
5776825bfe2SNicolas Vasilache     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
5780f241638SMatthias Springer     rewriter.updateRootInPlace(xferOp, [&]() {
5790f241638SMatthias Springer       xferOp.vectorMutable().assign(loadedVec);
5800f241638SMatthias Springer       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
5810f241638SMatthias Springer     });
5820f241638SMatthias Springer 
5830f241638SMatthias Springer     if (xferOp.mask()) {
5840f241638SMatthias Springer       rewriter.updateRootInPlace(
5850f241638SMatthias Springer           xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
5860f241638SMatthias Springer     }
5870f241638SMatthias Springer 
5880f241638SMatthias Springer     return success();
5890f241638SMatthias Springer   }
5900f241638SMatthias Springer };
5910f241638SMatthias Springer 
5920f241638SMatthias Springer /// Progressive lowering of vector transfer ops: Unpack one dimension.
5930f241638SMatthias Springer ///
5940f241638SMatthias Springer /// 1. Unpack one dimension from the current buffer type and cast the buffer
5950f241638SMatthias Springer ///    to that new type. E.g.:
5960f241638SMatthias Springer ///    ```
5970f241638SMatthias Springer ///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
5980f241638SMatthias Springer ///    vector.transfer_write %vec ...
5990f241638SMatthias Springer ///    ```
6000f241638SMatthias Springer ///    The following cast is generated:
6010f241638SMatthias Springer ///    ```
6020f241638SMatthias Springer ///    %casted = vector.type_cast %0
6030f241638SMatthias Springer ///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
6040f241638SMatthias Springer ///    ```
6050f241638SMatthias Springer /// 2. Generate a for loop and rewrite the transfer op according to the
6060f241638SMatthias Springer ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
6070f241638SMatthias Springer ///    out-of-bounds, generate an if-check and handle both cases separately.
6080f241638SMatthias Springer /// 3. Clean up according to the corresponding Strategy<OpTy>.
6090f241638SMatthias Springer template <typename OpTy>
6102ca887deSMatthias Springer struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
6112ca887deSMatthias Springer   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
6120f241638SMatthias Springer 
6130f241638SMatthias Springer   LogicalResult matchAndRewrite(OpTy xferOp,
6140f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
6150f241638SMatthias Springer     if (!xferOp->hasAttr(kPassLabel))
6160f241638SMatthias Springer       return failure();
6170f241638SMatthias Springer 
6180f241638SMatthias Springer     // Find and cast data buffer. How the buffer can be found depends on OpTy.
6196825bfe2SNicolas Vasilache     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
6200f241638SMatthias Springer     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
6210f241638SMatthias Springer     auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
6220f241638SMatthias Springer     auto castedDataType = unpackOneDim(dataBufferType);
6236825bfe2SNicolas Vasilache     auto castedDataBuffer =
6246825bfe2SNicolas Vasilache         locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
6250f241638SMatthias Springer 
6260f241638SMatthias Springer     // If the xferOp has a mask: Find and cast mask buffer.
6270f241638SMatthias Springer     Value castedMaskBuffer;
6280f241638SMatthias Springer     if (xferOp.mask()) {
6290f241638SMatthias Springer       auto maskBuffer = getMaskBuffer(xferOp);
6300f241638SMatthias Springer       auto maskBufferType =
6310f241638SMatthias Springer           maskBuffer.getType().template dyn_cast<MemRefType>();
6320f241638SMatthias Springer       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
6330f241638SMatthias Springer         // Do not unpack a dimension of the mask, if:
6340f241638SMatthias Springer         // * To-be-unpacked transfer op dimension is a broadcast.
6350f241638SMatthias Springer         // * Mask is 1D, i.e., the mask cannot be further unpacked.
6360f241638SMatthias Springer         //   (That means that all remaining dimensions of the transfer op must
6370f241638SMatthias Springer         //   be broadcasted.)
6380f241638SMatthias Springer         castedMaskBuffer = maskBuffer;
6390f241638SMatthias Springer       } else {
6400f241638SMatthias Springer         auto castedMaskType = unpackOneDim(maskBufferType);
6416825bfe2SNicolas Vasilache         castedMaskBuffer =
6426825bfe2SNicolas Vasilache             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
6430f241638SMatthias Springer       }
6440f241638SMatthias Springer     }
6450f241638SMatthias Springer 
6460f241638SMatthias Springer     // Loop bounds and step.
6476825bfe2SNicolas Vasilache     auto lb = locB.create<ConstantIndexOp>(0);
6486825bfe2SNicolas Vasilache     auto ub = locB.create<ConstantIndexOp>(
6496825bfe2SNicolas Vasilache         castedDataType.getDimSize(castedDataType.getRank() - 1));
6506825bfe2SNicolas Vasilache     auto step = locB.create<ConstantIndexOp>(1);
6510f241638SMatthias Springer 
6520f241638SMatthias Springer     // Generate for loop.
6536825bfe2SNicolas Vasilache     locB.create<scf::ForOp>(
6546825bfe2SNicolas Vasilache         lb, ub, step, ValueRange(),
6550f241638SMatthias Springer         [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
6560f241638SMatthias Springer           generateInBoundsCheck(
6576825bfe2SNicolas Vasilache               b, xferOp, iv, unpackedDim(xferOp),
6580f241638SMatthias Springer               /*inBoundsCase=*/
6596825bfe2SNicolas Vasilache               [&](OpBuilder &b, Location loc) {
6600f241638SMatthias Springer                 // Create new transfer op.
6612ca887deSMatthias Springer                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
6622ca887deSMatthias Springer                     b, this->options, xferOp, castedDataBuffer, iv);
6630f241638SMatthias Springer 
6640f241638SMatthias Springer                 // If old transfer op has a mask: Set mask on new transfer op.
6650f241638SMatthias Springer                 // Special case: If the mask of the old transfer op is 1D and
6660f241638SMatthias Springer                 // the
6670f241638SMatthias Springer                 //               unpacked dim is not a broadcast, no mask is
6680f241638SMatthias Springer                 //               needed on the new transfer op.
6690f241638SMatthias Springer                 if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
6700f241638SMatthias Springer                                       xferOp.getMaskType().getRank() > 1)) {
6710f241638SMatthias Springer                   OpBuilder::InsertionGuard guard(b);
6720f241638SMatthias Springer                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
6730f241638SMatthias Springer 
6740f241638SMatthias Springer                   SmallVector<Value, 8> loadIndices;
6750f241638SMatthias Springer                   Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
6760f241638SMatthias Springer                   // In case of broadcast: Use same indices to load from memref
6770f241638SMatthias Springer                   // as before.
6780f241638SMatthias Springer                   if (!xferOp.isBroadcastDim(0))
6790f241638SMatthias Springer                     loadIndices.push_back(iv);
6800f241638SMatthias Springer 
6816825bfe2SNicolas Vasilache                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
6826825bfe2SNicolas Vasilache                                                        loadIndices);
6830f241638SMatthias Springer                   rewriter.updateRootInPlace(
6840f241638SMatthias Springer                       newXfer, [&]() { newXfer.maskMutable().assign(mask); });
6850f241638SMatthias Springer                 }
6860f241638SMatthias Springer               },
6870f241638SMatthias Springer               /*outOfBoundsCase=*/
6880f241638SMatthias Springer               [&](OpBuilder &b, Location /*loc*/) {
6890f241638SMatthias Springer                 Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp,
6900f241638SMatthias Springer                                                      castedDataBuffer, iv);
6910f241638SMatthias Springer               });
6920f241638SMatthias Springer           b.create<scf::YieldOp>(loc);
6930f241638SMatthias Springer         });
6940f241638SMatthias Springer 
6950f241638SMatthias Springer     Strategy<OpTy>::cleanup(rewriter, xferOp);
6960f241638SMatthias Springer     return success();
6970f241638SMatthias Springer   }
6980f241638SMatthias Springer };
6990f241638SMatthias Springer 
700a088bed4SMatthias Springer } // namespace lowering_n_d
701a088bed4SMatthias Springer 
702a088bed4SMatthias Springer namespace lowering_n_d_unrolled {
703a088bed4SMatthias Springer 
7040f241638SMatthias Springer /// If the original transfer op has a mask, compute the mask of the new transfer
7050f241638SMatthias Springer /// op (for the current iteration `i`) and assign it.
7060f241638SMatthias Springer template <typename OpTy>
7076825bfe2SNicolas Vasilache static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
7080f241638SMatthias Springer                             int64_t i) {
7090f241638SMatthias Springer   if (!xferOp.mask())
7100f241638SMatthias Springer     return;
7110f241638SMatthias Springer 
7120f241638SMatthias Springer   if (xferOp.isBroadcastDim(0)) {
7130f241638SMatthias Springer     // To-be-unpacked dimension is a broadcast, which does not have a
7140f241638SMatthias Springer     // corresponding mask dimension. Mask attribute remains unchanged.
7150f241638SMatthias Springer     newXferOp.maskMutable().assign(xferOp.mask());
7160f241638SMatthias Springer     return;
7170f241638SMatthias Springer   }
7180f241638SMatthias Springer 
7190f241638SMatthias Springer   if (xferOp.getMaskType().getRank() > 1) {
7200f241638SMatthias Springer     // Unpack one dimension of the mask.
7216825bfe2SNicolas Vasilache     OpBuilder::InsertionGuard guard(b);
7226825bfe2SNicolas Vasilache     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
7230f241638SMatthias Springer 
7240f241638SMatthias Springer     llvm::SmallVector<int64_t, 1> indices({i});
7256825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
7266825bfe2SNicolas Vasilache     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
7270f241638SMatthias Springer     newXferOp.maskMutable().assign(newMask);
7280f241638SMatthias Springer   }
7290f241638SMatthias Springer 
7300f241638SMatthias Springer   // If we end up here: The mask of the old transfer op is 1D and the unpacked
7310f241638SMatthias Springer   // dim is not a broadcast, so no mask is needed on the new transfer op.
7320f241638SMatthias Springer   // `generateInBoundsCheck` will have evaluated the mask already.
7330f241638SMatthias Springer }
7340f241638SMatthias Springer 
7350f241638SMatthias Springer /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
7360f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
7370f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
7380f241638SMatthias Springer ///
7390f241638SMatthias Springer /// ```
7400f241638SMatthias Springer /// E.g.:
7410f241638SMatthias Springer /// ```
7420f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
7430f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<5x4xf32>
7440f241638SMatthias Springer /// ```
7450f241638SMatthias Springer /// is rewritten to IR such as (simplified):
7460f241638SMatthias Springer /// ```
7470f241638SMatthias Springer /// %v_init = splat %padding : vector<5x4xf32>
7480f241638SMatthias Springer /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
7490f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
7500f241638SMatthias Springer /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
7510f241638SMatthias Springer /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
7520f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
7530f241638SMatthias Springer /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
7540f241638SMatthias Springer /// ...
7550f241638SMatthias Springer /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
7560f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
7570f241638SMatthias Springer /// %vec = vector.insert %tmp1, %v3[4] : vector<4xf32> into vector<5x4xf32>
7580f241638SMatthias Springer /// ```
7590f241638SMatthias Springer ///
7600f241638SMatthias Springer /// Note: As an optimization, if the result of the original TransferReadOp
7610f241638SMatthias Springer /// was directly inserted into another vector, no new %v_init vector is created.
7620f241638SMatthias Springer /// Instead, the new TransferReadOp results are inserted into that vector.
7632ca887deSMatthias Springer struct UnrollTransferReadConversion
7642ca887deSMatthias Springer     : public VectorToSCFPattern<TransferReadOp> {
7652ca887deSMatthias Springer   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
7660f241638SMatthias Springer 
7670f241638SMatthias Springer   /// Return the vector into which the newly created TransferReadOp results
7680f241638SMatthias Springer   /// are inserted.
7690f241638SMatthias Springer   Value getResultVector(TransferReadOp xferOp,
7700f241638SMatthias Springer                         PatternRewriter &rewriter) const {
7710f241638SMatthias Springer     if (auto insertOp = getInsertOp(xferOp))
7720f241638SMatthias Springer       return insertOp.dest();
7736825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
7746825bfe2SNicolas Vasilache     return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
7756825bfe2SNicolas Vasilache                                     xferOp.padding());
7760f241638SMatthias Springer   }
7770f241638SMatthias Springer 
7780f241638SMatthias Springer   /// If the result of the TransferReadOp has exactly one user, which is a
7790f241638SMatthias Springer   /// vector::InsertOp, return that operation.
7800f241638SMatthias Springer   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
7810f241638SMatthias Springer     if (xferOp->hasOneUse()) {
7820f241638SMatthias Springer       Operation *xferOpUser = *xferOp->getUsers().begin();
7830f241638SMatthias Springer       if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
7840f241638SMatthias Springer         return insertOp;
7850f241638SMatthias Springer     }
7860f241638SMatthias Springer 
7870f241638SMatthias Springer     return vector::InsertOp();
7880f241638SMatthias Springer   }
7890f241638SMatthias Springer 
7900f241638SMatthias Springer   /// If the result of the TransferReadOp has exactly one user, which is a
7910f241638SMatthias Springer   /// vector::InsertOp, return that operation's indices.
7920f241638SMatthias Springer   void getInsertionIndices(TransferReadOp xferOp,
7930f241638SMatthias Springer                            SmallVector<int64_t, 8> &indices) const {
7940f241638SMatthias Springer     if (auto insertOp = getInsertOp(xferOp)) {
7950f241638SMatthias Springer       llvm::for_each(insertOp.position(), [&](Attribute attr) {
7960f241638SMatthias Springer         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
7970f241638SMatthias Springer       });
7980f241638SMatthias Springer     }
7990f241638SMatthias Springer   }
8000f241638SMatthias Springer 
8010f241638SMatthias Springer   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
8020f241638SMatthias Springer   /// accesses, and broadcasts and transposes in permutation maps.
8030f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferReadOp xferOp,
8040f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
8052ca887deSMatthias Springer     if (xferOp.getVectorType().getRank() <= options.targetRank)
8060f241638SMatthias Springer       return failure();
807*8fb48979SMatthias Springer     if (xferOp.getShapedType().template isa<RankedTensorType>())
808*8fb48979SMatthias Springer       return failure();
8090f241638SMatthias Springer 
8100f241638SMatthias Springer     auto insertOp = getInsertOp(xferOp);
8110f241638SMatthias Springer     auto vec = getResultVector(xferOp, rewriter);
8120f241638SMatthias Springer     auto vecType = vec.getType().dyn_cast<VectorType>();
8130f241638SMatthias Springer     auto xferVecType = xferOp.getVectorType();
8140f241638SMatthias Springer     auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
8150f241638SMatthias Springer                                           xferVecType.getElementType());
8160f241638SMatthias Springer     int64_t dimSize = xferVecType.getShape()[0];
8170f241638SMatthias Springer 
8180f241638SMatthias Springer     // Generate fully unrolled loop of transfer ops.
8196825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
8200f241638SMatthias Springer     for (int64_t i = 0; i < dimSize; ++i) {
8216825bfe2SNicolas Vasilache       Value iv = rewriter.create<ConstantIndexOp>(loc, i);
8220f241638SMatthias Springer 
8230f241638SMatthias Springer       vec = generateInBoundsCheck(
8246825bfe2SNicolas Vasilache           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
8250f241638SMatthias Springer           /*inBoundsCase=*/
8260f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
8270f241638SMatthias Springer             // Indices for the new transfer op.
8280f241638SMatthias Springer             SmallVector<Value, 8> xferIndices;
8296825bfe2SNicolas Vasilache             getXferIndices(b, xferOp, iv, xferIndices);
8300f241638SMatthias Springer 
8310f241638SMatthias Springer             // Indices for the new vector.insert op.
8320f241638SMatthias Springer             SmallVector<int64_t, 8> insertionIndices;
8330f241638SMatthias Springer             getInsertionIndices(xferOp, insertionIndices);
8340f241638SMatthias Springer             insertionIndices.push_back(i);
8350f241638SMatthias Springer 
8360f241638SMatthias Springer             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
8376825bfe2SNicolas Vasilache             auto newXferOp = b.create<vector::TransferReadOp>(
8386825bfe2SNicolas Vasilache                 loc, newXferVecType, xferOp.source(), xferIndices,
8396825bfe2SNicolas Vasilache                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
8406825bfe2SNicolas Vasilache                 xferOp.padding(), Value(), inBoundsAttr);
8410f241638SMatthias Springer             maybeAssignMask(b, xferOp, newXferOp, i);
8426825bfe2SNicolas Vasilache             return b.create<vector::InsertOp>(loc, newXferOp, vec,
8436825bfe2SNicolas Vasilache                                               insertionIndices);
8440f241638SMatthias Springer           },
8450f241638SMatthias Springer           /*outOfBoundsCase=*/
8460f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
8470f241638SMatthias Springer             // Loop through original (unmodified) vector.
8480f241638SMatthias Springer             return vec;
8490f241638SMatthias Springer           });
8500f241638SMatthias Springer     }
8510f241638SMatthias Springer 
8520f241638SMatthias Springer     if (insertOp) {
8530f241638SMatthias Springer       // Rewrite single user of the old TransferReadOp, which was an InsertOp.
8540f241638SMatthias Springer       rewriter.replaceOp(insertOp, vec);
8550f241638SMatthias Springer       rewriter.eraseOp(xferOp);
8560f241638SMatthias Springer     } else {
8570f241638SMatthias Springer       rewriter.replaceOp(xferOp, vec);
8580f241638SMatthias Springer     }
8590f241638SMatthias Springer 
8600f241638SMatthias Springer     return success();
8610f241638SMatthias Springer   }
8620f241638SMatthias Springer };
8630f241638SMatthias Springer 
8640f241638SMatthias Springer /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
8650f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
8660f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
8670f241638SMatthias Springer ///
8680f241638SMatthias Springer /// ```
8690f241638SMatthias Springer /// E.g.:
8700f241638SMatthias Springer /// ```
8710f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c]
8720f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
8730f241638SMatthias Springer /// ```
8740f241638SMatthias Springer /// is rewritten to IR such as (simplified):
8750f241638SMatthias Springer /// ```
8760f241638SMatthias Springer /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
8770f241638SMatthias Springer /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
8780f241638SMatthias Springer /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
8790f241638SMatthias Springer /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
8800f241638SMatthias Springer /// ...
8810f241638SMatthias Springer /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
8820f241638SMatthias Springer /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
8830f241638SMatthias Springer /// ```
8840f241638SMatthias Springer ///
8850f241638SMatthias Springer /// Note: As an optimization, if the vector of the original TransferWriteOp
8860f241638SMatthias Springer /// was directly extracted from another vector via an ExtractOp `a`, extract
8870f241638SMatthias Springer /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
8880f241638SMatthias Springer /// doing so, `a` may become dead, and the number of ExtractOps generated during
8890f241638SMatthias Springer /// recursive application of this pattern will be minimal.
8900f241638SMatthias Springer struct UnrollTransferWriteConversion
8912ca887deSMatthias Springer     : public VectorToSCFPattern<TransferWriteOp> {
8922ca887deSMatthias Springer   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
8930f241638SMatthias Springer 
8940f241638SMatthias Springer   /// Return the vector from which newly generated ExtracOps will extract.
8950f241638SMatthias Springer   Value getDataVector(TransferWriteOp xferOp) const {
8960f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp))
8970f241638SMatthias Springer       return extractOp.vector();
8980f241638SMatthias Springer     return xferOp.vector();
8990f241638SMatthias Springer   }
9000f241638SMatthias Springer 
9010f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
9020f241638SMatthias Springer   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
9030f241638SMatthias Springer     if (auto *op = xferOp.vector().getDefiningOp())
9040f241638SMatthias Springer       return dyn_cast<vector::ExtractOp>(op);
9050f241638SMatthias Springer     return vector::ExtractOp();
9060f241638SMatthias Springer   }
9070f241638SMatthias Springer 
9080f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return its
9090f241638SMatthias Springer   /// indices.
9100f241638SMatthias Springer   void getExtractionIndices(TransferWriteOp xferOp,
9110f241638SMatthias Springer                             SmallVector<int64_t, 8> &indices) const {
9120f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp)) {
9130f241638SMatthias Springer       llvm::for_each(extractOp.position(), [&](Attribute attr) {
9140f241638SMatthias Springer         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
9150f241638SMatthias Springer       });
9160f241638SMatthias Springer     }
9170f241638SMatthias Springer   }
9180f241638SMatthias Springer 
9190f241638SMatthias Springer   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
9200f241638SMatthias Springer   /// accesses, and broadcasts and transposes in permutation maps.
9210f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
9220f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
9232ca887deSMatthias Springer     if (xferOp.getVectorType().getRank() <= options.targetRank)
9240f241638SMatthias Springer       return failure();
925*8fb48979SMatthias Springer     if (xferOp.getShapedType().template isa<RankedTensorType>())
926*8fb48979SMatthias Springer       return failure();
9270f241638SMatthias Springer 
9280f241638SMatthias Springer     auto vec = getDataVector(xferOp);
9290f241638SMatthias Springer     auto xferVecType = xferOp.getVectorType();
9300f241638SMatthias Springer     int64_t dimSize = xferVecType.getShape()[0];
9310f241638SMatthias Springer 
9320f241638SMatthias Springer     // Generate fully unrolled loop of transfer ops.
9336825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
9340f241638SMatthias Springer     for (int64_t i = 0; i < dimSize; ++i) {
9356825bfe2SNicolas Vasilache       Value iv = rewriter.create<ConstantIndexOp>(loc, i);
9360f241638SMatthias Springer 
9370f241638SMatthias Springer       generateInBoundsCheck(
9386825bfe2SNicolas Vasilache           rewriter, xferOp, iv, unpackedDim(xferOp),
9390f241638SMatthias Springer           /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
9400f241638SMatthias Springer             // Indices for the new transfer op.
9410f241638SMatthias Springer             SmallVector<Value, 8> xferIndices;
9426825bfe2SNicolas Vasilache             getXferIndices(b, xferOp, iv, xferIndices);
9430f241638SMatthias Springer 
9440f241638SMatthias Springer             // Indices for the new vector.extract op.
9450f241638SMatthias Springer             SmallVector<int64_t, 8> extractionIndices;
9460f241638SMatthias Springer             getExtractionIndices(xferOp, extractionIndices);
9470f241638SMatthias Springer             extractionIndices.push_back(i);
9480f241638SMatthias Springer 
9496825bfe2SNicolas Vasilache             auto extracted =
9506825bfe2SNicolas Vasilache                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
9510f241638SMatthias Springer             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
9520f241638SMatthias Springer 
9536825bfe2SNicolas Vasilache             auto newXferOp = b.create<vector::TransferWriteOp>(
9546825bfe2SNicolas Vasilache                 loc, Type(), extracted, xferOp.source(), xferIndices,
9556825bfe2SNicolas Vasilache                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
9566825bfe2SNicolas Vasilache                 inBoundsAttr);
9570f241638SMatthias Springer 
9580f241638SMatthias Springer             maybeAssignMask(b, xferOp, newXferOp, i);
9590f241638SMatthias Springer           });
9600f241638SMatthias Springer     }
9610f241638SMatthias Springer 
9620f241638SMatthias Springer     rewriter.eraseOp(xferOp);
9630f241638SMatthias Springer     return success();
9640f241638SMatthias Springer   }
9650f241638SMatthias Springer };
9660f241638SMatthias Springer 
967a088bed4SMatthias Springer } // namespace lowering_n_d_unrolled
968a088bed4SMatthias Springer 
969a088bed4SMatthias Springer namespace lowering_1_d {
970a088bed4SMatthias Springer 
9710f241638SMatthias Springer /// Compute the indices into the memref for the LoadOp/StoreOp generated as
9720f241638SMatthias Springer /// part of TransferOp1dConversion. Return the memref dimension on which
9730f241638SMatthias Springer /// the transfer is operating. A return value of None indicates a broadcast.
9740f241638SMatthias Springer template <typename OpTy>
9750f241638SMatthias Springer static Optional<int64_t>
9766825bfe2SNicolas Vasilache get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
9770f241638SMatthias Springer                    SmallVector<Value, 8> &memrefIndices) {
9780f241638SMatthias Springer   auto indices = xferOp.indices();
9790f241638SMatthias Springer   auto map = xferOp.permutation_map();
9800f241638SMatthias Springer 
9810f241638SMatthias Springer   memrefIndices.append(indices.begin(), indices.end());
9820f241638SMatthias Springer   assert(map.getNumResults() == 1 &&
9830f241638SMatthias Springer          "Expected 1 permutation map result for 1D transfer");
9840f241638SMatthias Springer   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
9856825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
9860f241638SMatthias Springer     auto dim = expr.getPosition();
9876825bfe2SNicolas Vasilache     AffineExpr d0, d1;
9886825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
9896825bfe2SNicolas Vasilache     Value offset = memrefIndices[dim];
9906825bfe2SNicolas Vasilache     memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
9910f241638SMatthias Springer     return dim;
9920f241638SMatthias Springer   }
9930f241638SMatthias Springer 
9940f241638SMatthias Springer   assert(xferOp.isBroadcastDim(0) &&
9950f241638SMatthias Springer          "Expected AffineDimExpr or AffineConstantExpr");
9960f241638SMatthias Springer   return None;
9970f241638SMatthias Springer }
9980f241638SMatthias Springer 
9990f241638SMatthias Springer /// Codegen strategy for TransferOp1dConversion, depending on the
10000f241638SMatthias Springer /// operation.
10010f241638SMatthias Springer template <typename OpTy>
10020f241638SMatthias Springer struct Strategy1d;
10030f241638SMatthias Springer 
10040f241638SMatthias Springer /// Codegen strategy for TransferReadOp.
10050f241638SMatthias Springer template <>
10060f241638SMatthias Springer struct Strategy1d<TransferReadOp> {
10076825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
10080f241638SMatthias Springer                                   TransferReadOp xferOp, Value iv,
10090f241638SMatthias Springer                                   ValueRange loopState) {
10100f241638SMatthias Springer     SmallVector<Value, 8> indices;
10116825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
10126825bfe2SNicolas Vasilache     Value ivI32 =
10136825bfe2SNicolas Vasilache         b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
10140f241638SMatthias Springer     auto vec = loopState[0];
10150f241638SMatthias Springer 
10160f241638SMatthias Springer     // In case of out-of-bounds access, leave `vec` as is (was initialized with
10170f241638SMatthias Springer     // padding value).
10180f241638SMatthias Springer     auto nextVec = generateInBoundsCheck(
10196825bfe2SNicolas Vasilache         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
10200f241638SMatthias Springer         /*inBoundsCase=*/
10216825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
10226825bfe2SNicolas Vasilache           Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
10236825bfe2SNicolas Vasilache           return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
10240f241638SMatthias Springer         },
10250f241638SMatthias Springer         /*outOfBoundsCase=*/
10260f241638SMatthias Springer         [&](OpBuilder & /*b*/, Location loc) { return vec; });
10276825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc, nextVec);
10280f241638SMatthias Springer   }
10290f241638SMatthias Springer 
10306825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
10310f241638SMatthias Springer     // Inititalize vector with padding value.
10326825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
10336825bfe2SNicolas Vasilache     return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
10340f241638SMatthias Springer   }
10350f241638SMatthias Springer };
10360f241638SMatthias Springer 
10370f241638SMatthias Springer /// Codegen strategy for TransferWriteOp.
10380f241638SMatthias Springer template <>
10390f241638SMatthias Springer struct Strategy1d<TransferWriteOp> {
10406825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
10410f241638SMatthias Springer                                   TransferWriteOp xferOp, Value iv,
10420f241638SMatthias Springer                                   ValueRange /*loopState*/) {
10430f241638SMatthias Springer     SmallVector<Value, 8> indices;
10446825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
10456825bfe2SNicolas Vasilache     Value ivI32 =
10466825bfe2SNicolas Vasilache         b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
10470f241638SMatthias Springer 
10480f241638SMatthias Springer     // Nothing to do in case of out-of-bounds access.
10490f241638SMatthias Springer     generateInBoundsCheck(
10506825bfe2SNicolas Vasilache         b, xferOp, iv, dim,
10516825bfe2SNicolas Vasilache         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
10526825bfe2SNicolas Vasilache           auto val =
10536825bfe2SNicolas Vasilache               b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
10546825bfe2SNicolas Vasilache           b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
10550f241638SMatthias Springer         });
10566825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc);
10570f241638SMatthias Springer   }
10580f241638SMatthias Springer 
10596825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
10606825bfe2SNicolas Vasilache     return Value();
10616825bfe2SNicolas Vasilache   }
10620f241638SMatthias Springer };
10630f241638SMatthias Springer 
10640f241638SMatthias Springer /// Return true if the last dimension of the MemRefType has unit stride.
10650f241638SMatthias Springer static bool isLastMemrefDimUnitStride(MemRefType type) {
10660f241638SMatthias Springer   int64_t offset;
10670f241638SMatthias Springer   SmallVector<int64_t, 4> strides;
10680f241638SMatthias Springer   auto successStrides = getStridesAndOffset(type, strides, offset);
10690f241638SMatthias Springer   return succeeded(successStrides) && strides.back() == 1;
10700f241638SMatthias Springer }
10710f241638SMatthias Springer 
10720f241638SMatthias Springer /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
10730f241638SMatthias Springer /// necessary in cases where a 1D vector transfer op cannot be lowered into
10740f241638SMatthias Springer /// vector load/stores due to non-unit strides or broadcasts:
10750f241638SMatthias Springer ///
10760f241638SMatthias Springer /// * Transfer dimension is not the last memref dimension
10770f241638SMatthias Springer /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
10780f241638SMatthias Springer /// * Memref has a layout map with non-unit stride on the last dimension
10790f241638SMatthias Springer ///
10800f241638SMatthias Springer /// This pattern generates IR as follows:
10810f241638SMatthias Springer ///
10820f241638SMatthias Springer /// 1. Generate a for loop iterating over each vector element.
10830f241638SMatthias Springer /// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
10840f241638SMatthias Springer ///    depending on OpTy.
10850f241638SMatthias Springer ///
10860f241638SMatthias Springer /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
10870f241638SMatthias Springer ///       can be generated instead of TransferOp1dConversion. Add such a pattern
10880f241638SMatthias Springer ///       to ConvertVectorToLLVM.
10890f241638SMatthias Springer ///
10900f241638SMatthias Springer /// E.g.:
10910f241638SMatthias Springer /// ```
10920f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b]
10930f241638SMatthias Springer ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
10940f241638SMatthias Springer ///    : vector<9xf32>, memref<?x?xf32>
10950f241638SMatthias Springer /// ```
10960f241638SMatthias Springer /// Is rewritten to approximately the following pseudo-IR:
10970f241638SMatthias Springer /// ```
10980f241638SMatthias Springer /// for i = 0 to 9 {
10990f241638SMatthias Springer ///   %t = vector.extractelement %vec[i] : vector<9xf32>
11000f241638SMatthias Springer ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
11010f241638SMatthias Springer /// }
11020f241638SMatthias Springer /// ```
11030f241638SMatthias Springer template <typename OpTy>
11042ca887deSMatthias Springer struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
11052ca887deSMatthias Springer   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
11060f241638SMatthias Springer 
11070f241638SMatthias Springer   LogicalResult matchAndRewrite(OpTy xferOp,
11080f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
11090f241638SMatthias Springer     auto map = xferOp.permutation_map();
11100f241638SMatthias Springer     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
11110f241638SMatthias Springer 
11120f241638SMatthias Springer     if (!memRefType)
11130f241638SMatthias Springer       return failure();
11140f241638SMatthias Springer     if (xferOp.getVectorType().getRank() != 1)
11150f241638SMatthias Springer       return failure();
11160f241638SMatthias Springer     if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
11170f241638SMatthias Springer       return failure(); // Handled by ConvertVectorToLLVM
11180f241638SMatthias Springer 
11190f241638SMatthias Springer     // Loop bounds, step, state...
11206825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
11210f241638SMatthias Springer     auto vecType = xferOp.getVectorType();
11226825bfe2SNicolas Vasilache     auto lb = rewriter.create<ConstantIndexOp>(loc, 0);
11236825bfe2SNicolas Vasilache     auto ub = rewriter.create<ConstantIndexOp>(loc, vecType.getDimSize(0));
11246825bfe2SNicolas Vasilache     auto step = rewriter.create<ConstantIndexOp>(loc, 1);
11256825bfe2SNicolas Vasilache     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
11260f241638SMatthias Springer 
11270f241638SMatthias Springer     // Generate for loop.
11280f241638SMatthias Springer     rewriter.replaceOpWithNewOp<scf::ForOp>(
11290f241638SMatthias Springer         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
11306825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
11316825bfe2SNicolas Vasilache           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
11320f241638SMatthias Springer         });
11330f241638SMatthias Springer 
11340f241638SMatthias Springer     return success();
11350f241638SMatthias Springer   }
11360f241638SMatthias Springer };
11374ead2cf7SAlex Zinenko 
1138a088bed4SMatthias Springer } // namespace lowering_1_d
1139df63eedeSBenjamin Kramer } // namespace
1140df63eedeSBenjamin Kramer 
114151d30c34SBenjamin Kramer namespace mlir {
114251d30c34SBenjamin Kramer 
11433393cc4cSNicolas Vasilache void populateVectorToSCFConversionPatterns(
1144dc4e913bSChris Lattner     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
11450f241638SMatthias Springer   if (options.unroll) {
1146a088bed4SMatthias Springer     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1147a088bed4SMatthias Springer                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
11482ca887deSMatthias Springer         patterns.getContext(), options);
11490f241638SMatthias Springer   } else {
1150a088bed4SMatthias Springer     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1151a088bed4SMatthias Springer                  lowering_n_d::PrepareTransferWriteConversion,
1152a088bed4SMatthias Springer                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1153a088bed4SMatthias Springer                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1154a088bed4SMatthias Springer         patterns.getContext(), options);
11550f241638SMatthias Springer   }
11560f241638SMatthias Springer 
11572ca887deSMatthias Springer   if (options.targetRank == 1) {
1158a088bed4SMatthias Springer     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1159a088bed4SMatthias Springer                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1160a088bed4SMatthias Springer         patterns.getContext(), options);
11610f241638SMatthias Springer   }
11624ead2cf7SAlex Zinenko }
11633393cc4cSNicolas Vasilache 
11643393cc4cSNicolas Vasilache } // namespace mlir
11653393cc4cSNicolas Vasilache 
11665f9e0466SNicolas Vasilache namespace {
11675f9e0466SNicolas Vasilache 
11685f9e0466SNicolas Vasilache struct ConvertVectorToSCFPass
11695f9e0466SNicolas Vasilache     : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
11705f9e0466SNicolas Vasilache   ConvertVectorToSCFPass() = default;
11715f9e0466SNicolas Vasilache   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
11725f9e0466SNicolas Vasilache     this->fullUnroll = options.unroll;
11732ca887deSMatthias Springer     this->targetRank = options.targetRank;
1174fb7ec1f1SMatthias Springer     this->lowerPermutationMaps = options.lowerPermutationMaps;
11755f9e0466SNicolas Vasilache   }
11765f9e0466SNicolas Vasilache 
11775f9e0466SNicolas Vasilache   void runOnFunction() override {
11782ca887deSMatthias Springer     VectorTransferToSCFOptions options;
1179fb7ec1f1SMatthias Springer     options.unroll = fullUnroll;
1180fb7ec1f1SMatthias Springer     options.targetRank = targetRank;
1181fb7ec1f1SMatthias Springer     options.lowerPermutationMaps = lowerPermutationMaps;
1182fb7ec1f1SMatthias Springer 
1183fb7ec1f1SMatthias Springer     // Lower permutation maps first.
1184fb7ec1f1SMatthias Springer     if (lowerPermutationMaps) {
1185fb7ec1f1SMatthias Springer       RewritePatternSet lowerTransferPatterns(getFunction().getContext());
1186fb7ec1f1SMatthias Springer       mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1187fb7ec1f1SMatthias Springer           lowerTransferPatterns);
1188fb7ec1f1SMatthias Springer       (void)applyPatternsAndFoldGreedily(getFunction(),
1189fb7ec1f1SMatthias Springer                                          std::move(lowerTransferPatterns));
1190fb7ec1f1SMatthias Springer     }
11912ca887deSMatthias Springer 
1192dc4e913bSChris Lattner     RewritePatternSet patterns(getFunction().getContext());
11932ca887deSMatthias Springer     populateVectorToSCFConversionPatterns(patterns, options);
1194e21adfa3SRiver Riddle     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
11955f9e0466SNicolas Vasilache   }
11965f9e0466SNicolas Vasilache };
11975f9e0466SNicolas Vasilache 
11985f9e0466SNicolas Vasilache } // namespace
11995f9e0466SNicolas Vasilache 
12005f9e0466SNicolas Vasilache std::unique_ptr<Pass>
12015f9e0466SNicolas Vasilache mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
12025f9e0466SNicolas Vasilache   return std::make_unique<ConvertVectorToSCFPass>(options);
12035f9e0466SNicolas Vasilache }
1204