//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector transfer operations to SCF.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
#include "mlir/Dialect/MemRef/EDSC/Intrinsics.h"
#include "mlir/Dialect/SCF/EDSC/Intrinsics.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/EDSC/Intrinsics.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using vector::TransferReadOp;
using vector::TransferWriteOp;

namespace {

/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  VectorTransferToSCFOptions options;
};

/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static MemRefType unpackOneDim(MemRefType type) {
  auto vectorType = type.getElementType().dyn_cast<VectorType>();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::get(vectorType.getShape().drop_front(),
                                         vectorType.getElementType()));
}

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

/// Allocate temporary buffers for data (vector) and mask (if present).
/// TODO: Parallelism and threadlocal considerations.
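///
/// For illustration (types and value names are made up), a transfer op with
/// vector type vector<5x4xf32> and a mask of type vector<5x4xi1> may get
/// roughly the following IR at the start of the enclosing allocation scope:
/// ```
/// %data_buf = memref.alloca() : memref<vector<5x4xf32>>
/// %mask_alloc = memref.alloca() : memref<vector<5x4xi1>>
/// memref.store %mask, %mask_alloc[] : memref<vector<5x4xi1>>
/// %mask_val = memref.load %mask_alloc[] : memref<vector<5x4xi1>>
/// ```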
template <typename OpTy>
static BufferAllocs allocBuffers(OpTy xferOp) {
  auto &b = ScopedContext::getBuilderRef();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope =
      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = memref_alloca(bufferType).value;

  if (xferOp.mask()) {
    auto maskType = MemRefType::get({}, xferOp.mask().getType());
    Value maskBuffer = memref_alloca(maskType);
    memref_store(xferOp.mask(), maskBuffer);
    result.maskBuffer = memref_load(maskBuffer);
  }

  return result;
}

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
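///
/// E.g. (illustrative): for permutation map (d0, d1, d2) -> (d2, d1), dimension
/// 2 is unpacked next. For (d0, d1) -> (0, d1), the first result is a constant
/// (broadcast), so None is returned.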
template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
  auto map = xferOp.permutation_map();
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
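///
/// E.g. (illustrative): (d0, d1, d2) -> (d2, d0) becomes (d0, d1, d2) -> (d0).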
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpTy xferOp, OpBuilder &builder) {
  auto map = xferOp.permutation_map();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        builder.getContext());
}

/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///                                 ^^^^^^
///              `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.indices();
  indices.append(prevIndices.begin(), prevIndices.end());

  bool isBroadcast = !dim.hasValue();
  if (!isBroadcast) {
    using edsc::op::operator+;
    indices[dim.getValue()] = adaptor.indices()[dim.getValue()] + iv;
  }
}

static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc,
                            Value value) {
  if (hasRetVal) {
    builder.create<scf::YieldOp>(loc, value);
  } else {
    builder.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
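///
/// For a 1-D mask, roughly the following IR is generated (illustrative names):
/// ```
/// %i = index_cast %iv : index to i32
/// %is_set = vector.extractelement %mask[%i : i32] : vector<5xi1>
/// ```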
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &builder, OpTy xferOp, Value iv) {
  if (!xferOp.mask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
  return vector_extract_element(xferOp.mask(), ivI32).value;
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b], %cst
///     : memref<?x?xf32>, vector<5x4xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpTy xferOp, Value iv, OpBuilder &builder, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    auto memrefDim =
        memref_dim(xferOp.source(), std_constant_index(dim.getValue()));
    using edsc::op::operator+;
    auto memrefIdx = xferOp.indices()[dim.getValue()] + iv;
    cond = std_cmpi_sgt(memrefDim.value, memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(builder, xferOp, iv)) {
    if (cond) {
      cond = builder.create<AndOp>(xferOp.getLoc(), cond, maskCond);
    } else {
      cond = maskCond;
    }
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = builder.create<scf::IfOp>(
        xferOp.getLoc(), resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &builder, Location loc) {
          maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &builder, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(hasRetVal, builder, loc,
                            outOfBoundsCase(builder, loc));
          } else {
            builder.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(builder, xferOp.getLoc());
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpTy xferOp, Value iv, OpBuilder &builder, Optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      xferOp, iv, builder, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &builder, Location loc) {
        inBoundsCase(builder, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &builder, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(builder, loc);
        return Value();
      });
}

/// Given an ArrayAttr, return a copy where the first element is dropped.
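///
/// E.g. (illustrative): [true, false, true] --> [false, true]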
static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) {
  if (!attr)
    return attr;
  return ArrayAttr::get(builder.getContext(), attr.getValue().drop_front());
}

/// Add the pass label to a vector transfer op if its rank is greater than the
/// target rank.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, builder.getUnitAttr());
}

/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.mask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &builder,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer,
                                  Value iv) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(xferOp, iv, xferIndices);

    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
    auto newXfer =
        vector_transfer_read(
            vecType, xferOp.source(), xferIndices,
            AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)),
            xferOp.padding(), Value(), inBoundsAttr)
            .value;

    maybeApplyPassLabel(builder,
                        dyn_cast<TransferReadOp>(newXfer.getDefiningOp()),
                        options.targetRank);

    memref_store(newXfer, buffer, storeIndices);
    return newXfer.getDefiningOp<TransferReadOp>();
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
  static void handleOutOfBoundsDim(OpBuilder & /*builder*/,
                                   TransferReadOp xferOp, Value buffer,
                                   Value iv) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto vec = std_splat(vecType, xferOp.padding());
    memref_store(vec, buffer, storeIndices);
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }
};

/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
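  ///
  /// E.g. (mirroring the TransferReadOp example above; names are illustrative):
  /// ```
  /// %vec = memref.load %buf[%i] : memref<5xvector<4x3xf32>>
  /// vector.transfer_write %vec, %A[%a+%i, %b, %c]
  ///     : vector<4x3xf32>, memref<?x?x?xf32>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = memref.load %casted[%i, %j] : memref<5x4xvector<3xf32>>
  ///   vector.transfer_write %vec, %A[%a+%i, %b+%j, %c]
  ///       : vector<3xf32>, memref<?x?x?xf32>
  /// }
  /// ```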
  static TransferWriteOp rewriteOp(OpBuilder &builder,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(xferOp, iv, xferIndices);

    auto vec = memref_load(buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr());
    auto newXfer = vector_transfer_write(
        Type(), vec, xferOp.source(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(builder, newXfer.op, options.targetRank);

    return newXfer;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static void handleOutOfBoundsDim(OpBuilder &builder, TransferWriteOp xferOp,
                                   Value buffer, Value iv) {}

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
    rewriter.eraseOp(xferOp);
  }
};

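/// Return success if the given transfer op is eligible for the "prepare"
/// patterns below: the op has not been labeled yet and its vector rank is
/// greater than the target rank.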
template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp, unsigned targetRank) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= targetRank)
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options.targetRank).failed())
      return failure();

    ScopedContext scope(rewriter, xferOp.getLoc());
    auto buffers = allocBuffers(xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.mask()) {
      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
          buffers.maskBuffer);
    }

    memref_store(newXfer->getResult(0), buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};

/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options.targetRank).failed())
      return failure();

    ScopedContext scope(rewriter, xferOp.getLoc());
    auto buffers = allocBuffers(xferOp);
    memref_store(xferOp.vector(), buffers.dataBuffer);
    auto loadedVec = memref_load(buffers.dataBuffer);
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.vectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.mask()) {
      rewriter.updateRootInPlace(
          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
    }

    return success();
  }
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    ScopedContext scope(rewriter, xferOp.getLoc());

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer = vector_type_cast(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.mask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = std_constant_index(0).value;
    auto ub = std_constant_index(
                  castedDataType.getDimSize(castedDataType.getRank() - 1))
                  .value;
    auto step = std_constant_index(1).value;

    // Generate for loop.
    rewriter.create<scf::ForOp>(
        xferOp.getLoc(), lb, ub, step, ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
          ScopedContext scope(b, loc);
          generateInBoundsCheck(
              xferOp, iv, b, unpackedDim(xferOp),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
                if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
                                      xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from memref
                  // as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = memref_load(castedMaskBuffer, loadIndices);
                  rewriter.updateRootInPlace(
                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                }
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp,
                                                     castedDataBuffer, iv);
              });
          b.create<scf::YieldOp>(loc);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp);
    return success();
  }
};

/// If the original transfer op has a mask, compute the mask of the new transfer
/// op (for the current iteration `i`) and assign it.
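///
/// E.g. (illustrative), for a mask of type vector<5x4xi1> and iteration i = 2,
/// the unpacked mask of the new transfer op is computed as:
/// ```
/// %new_mask = vector.extract %mask[2] : vector<5x4xi1>
/// ```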
template <typename OpTy>
static void maybeAssignMask(OpBuilder &builder, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.mask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.maskMutable().assign(xferOp.mask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    auto newMask = vector_extract(xferOp.mask(), indices).value;
    newXferOp.maskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
/// ```
///
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is created.
/// Instead, the new TransferReadOp results are inserted into that vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.dest();
    return std_splat(xferOp.getVectorType(), xferOp.padding()).value;
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      llvm::for_each(insertOp.position(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();

    ScopedContext scope(rewriter, xferOp.getLoc());
    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

8010f241638SMatthias Springer     // Generate fully unrolled loop of transfer ops.
8020f241638SMatthias Springer     for (int64_t i = 0; i < dimSize; ++i) {
8030f241638SMatthias Springer       Value iv = std_constant_index(i);
8040f241638SMatthias Springer 
8050f241638SMatthias Springer       vec = generateInBoundsCheck(
8060f241638SMatthias Springer           xferOp, iv, rewriter, unpackedDim(xferOp), TypeRange(vecType),
8070f241638SMatthias Springer           /*inBoundsCase=*/
8080f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
8090f241638SMatthias Springer             ScopedContext scope(b, loc);
8100f241638SMatthias Springer 
8110f241638SMatthias Springer             // Indices for the new transfer op.
8120f241638SMatthias Springer             SmallVector<Value, 8> xferIndices;
8130f241638SMatthias Springer             getXferIndices(xferOp, iv, xferIndices);
8140f241638SMatthias Springer 
8150f241638SMatthias Springer             // Indices for the new vector.insert op.
8160f241638SMatthias Springer             SmallVector<int64_t, 8> insertionIndices;
8170f241638SMatthias Springer             getInsertionIndices(xferOp, insertionIndices);
8180f241638SMatthias Springer             insertionIndices.push_back(i);
8190f241638SMatthias Springer 
8200f241638SMatthias Springer             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
8210f241638SMatthias Springer             auto newXferOpVal =
8220f241638SMatthias Springer                 vector_transfer_read(
8230f241638SMatthias Springer                     newXferVecType, xferOp.source(), xferIndices,
8240f241638SMatthias Springer                     AffineMapAttr::get(unpackedPermutationMap(xferOp, b)),
8250f241638SMatthias Springer                     xferOp.padding(), Value(), inBoundsAttr)
8260f241638SMatthias Springer                     .value;
8270f241638SMatthias Springer             auto newXferOp =
8280f241638SMatthias Springer                 dyn_cast<TransferReadOp>(newXferOpVal.getDefiningOp());
8290f241638SMatthias Springer 
8300f241638SMatthias Springer             maybeAssignMask(b, xferOp, newXferOp, i);
8310f241638SMatthias Springer 
8320f241638SMatthias Springer             return vector_insert(newXferOp, vec, insertionIndices).value;
8330f241638SMatthias Springer           },
8340f241638SMatthias Springer           /*outOfBoundsCase=*/
8350f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
8360f241638SMatthias Springer             // Out-of-bounds: pass the accumulated vector through unmodified.
8370f241638SMatthias Springer             return vec;
8380f241638SMatthias Springer           });
8390f241638SMatthias Springer     }
8400f241638SMatthias Springer 
8410f241638SMatthias Springer     if (insertOp) {
8420f241638SMatthias Springer       // Rewrite the single user of the old TransferReadOp, which was an InsertOp.
8430f241638SMatthias Springer       rewriter.replaceOp(insertOp, vec);
8440f241638SMatthias Springer       rewriter.eraseOp(xferOp);
8450f241638SMatthias Springer     } else {
8460f241638SMatthias Springer       rewriter.replaceOp(xferOp, vec);
8470f241638SMatthias Springer     }
8480f241638SMatthias Springer 
8490f241638SMatthias Springer     return success();
8500f241638SMatthias Springer   }
8510f241638SMatthias Springer };
8520f241638SMatthias Springer 
8530f241638SMatthias Springer /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
8540f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
8550f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
8560f241638SMatthias Springer ///
8580f241638SMatthias Springer /// E.g.:
8590f241638SMatthias Springer /// ```
8600f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c]
8610f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
8620f241638SMatthias Springer /// ```
8630f241638SMatthias Springer /// is rewritten to IR such as (simplified):
8640f241638SMatthias Springer /// ```
8650f241638SMatthias Springer /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
8660f241638SMatthias Springer /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
8670f241638SMatthias Springer /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
8680f241638SMatthias Springer /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
8690f241638SMatthias Springer /// ...
8700f241638SMatthias Springer /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
8710f241638SMatthias Springer /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
8720f241638SMatthias Springer /// ```
8730f241638SMatthias Springer ///
8740f241638SMatthias Springer /// Note: As an optimization, if the vector of the original TransferWriteOp
8750f241638SMatthias Springer /// was directly extracted from another vector via an ExtractOp `a`, extract
8760f241638SMatthias Springer /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
8770f241638SMatthias Springer /// doing so, `a` may become dead, and the number of ExtractOps generated during
8780f241638SMatthias Springer /// recursive application of this pattern will be minimal.
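/// For illustration (a simplified sketch; `%big` is a made-up name), suppose
/// the written vector above was itself extracted from a larger vector:
/// ```
/// %vec = vector.extract %big[0] : vector<2x5x4xf32>
/// ```
/// Then the unrolled writes extract directly from `%big`:
/// ```
/// %v0 = vector.extract %big[0, 0] : vector<2x5x4xf32>
/// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
/// %v1 = vector.extract %big[0, 1] : vector<2x5x4xf32>
/// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
/// ...
/// ```
/// and the intermediate ExtractOp may become dead.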
8790f241638SMatthias Springer struct UnrollTransferWriteConversion
880*2ca887deSMatthias Springer     : public VectorToSCFPattern<TransferWriteOp> {
881*2ca887deSMatthias Springer   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
8820f241638SMatthias Springer 
8830f241638SMatthias Springer   /// Return the vector from which newly generated ExtractOps will extract.
8840f241638SMatthias Springer   Value getDataVector(TransferWriteOp xferOp) const {
8850f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp))
8860f241638SMatthias Springer       return extractOp.vector();
8870f241638SMatthias Springer     return xferOp.vector();
8880f241638SMatthias Springer   }
8890f241638SMatthias Springer 
8900f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
8910f241638SMatthias Springer   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
8920f241638SMatthias Springer     if (auto *op = xferOp.vector().getDefiningOp())
8930f241638SMatthias Springer       return dyn_cast<vector::ExtractOp>(op);
8940f241638SMatthias Springer     return vector::ExtractOp();
8950f241638SMatthias Springer   }
8960f241638SMatthias Springer 
8970f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return its
8980f241638SMatthias Springer   /// indices.
8990f241638SMatthias Springer   void getExtractionIndices(TransferWriteOp xferOp,
9000f241638SMatthias Springer                             SmallVector<int64_t, 8> &indices) const {
9010f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp)) {
9020f241638SMatthias Springer       llvm::for_each(extractOp.position(), [&](Attribute attr) {
9030f241638SMatthias Springer         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
9040f241638SMatthias Springer       });
9050f241638SMatthias Springer     }
9060f241638SMatthias Springer   }
9070f241638SMatthias Springer 
9080f241638SMatthias Springer   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
9090f241638SMatthias Springer   /// accesses, and broadcasts and transposes in permutation maps.
9100f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
9110f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
912*2ca887deSMatthias Springer     if (xferOp.getVectorType().getRank() <= options.targetRank)
9130f241638SMatthias Springer       return failure();
9140f241638SMatthias Springer 
9150f241638SMatthias Springer     ScopedContext scope(rewriter, xferOp.getLoc());
9160f241638SMatthias Springer     auto vec = getDataVector(xferOp);
9170f241638SMatthias Springer     auto xferVecType = xferOp.getVectorType();
9180f241638SMatthias Springer     int64_t dimSize = xferVecType.getShape()[0];
9190f241638SMatthias Springer 
9200f241638SMatthias Springer     // Generate fully unrolled loop of transfer ops.
9210f241638SMatthias Springer     for (int64_t i = 0; i < dimSize; ++i) {
9220f241638SMatthias Springer       Value iv = std_constant_index(i);
9230f241638SMatthias Springer 
9240f241638SMatthias Springer       generateInBoundsCheck(
9250f241638SMatthias Springer           xferOp, iv, rewriter, unpackedDim(xferOp),
9260f241638SMatthias Springer           /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
9270f241638SMatthias Springer             ScopedContext scope(b, loc);
9280f241638SMatthias Springer 
9290f241638SMatthias Springer             // Indices for the new transfer op.
9300f241638SMatthias Springer             SmallVector<Value, 8> xferIndices;
9310f241638SMatthias Springer             getXferIndices(xferOp, iv, xferIndices);
9320f241638SMatthias Springer 
9330f241638SMatthias Springer             // Indices for the new vector.extract op.
9340f241638SMatthias Springer             SmallVector<int64_t, 8> extractionIndices;
9350f241638SMatthias Springer             getExtractionIndices(xferOp, extractionIndices);
9360f241638SMatthias Springer             extractionIndices.push_back(i);
9370f241638SMatthias Springer 
9380f241638SMatthias Springer             auto extracted = vector_extract(vec, extractionIndices).value;
9390f241638SMatthias Springer             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
9400f241638SMatthias Springer 
9410f241638SMatthias Springer             auto newXferOp =
9420f241638SMatthias Springer                 vector_transfer_write(
9430f241638SMatthias Springer                     Type(), extracted, xferOp.source(), xferIndices,
9440f241638SMatthias Springer                     AffineMapAttr::get(unpackedPermutationMap(xferOp, b)),
9450f241638SMatthias Springer                     Value(), inBoundsAttr)
9460f241638SMatthias Springer                     .op;
9470f241638SMatthias Springer 
9480f241638SMatthias Springer             maybeAssignMask(b, xferOp, newXferOp, i);
9490f241638SMatthias Springer           });
9500f241638SMatthias Springer     }
9510f241638SMatthias Springer 
9520f241638SMatthias Springer     rewriter.eraseOp(xferOp);
9530f241638SMatthias Springer     return success();
9540f241638SMatthias Springer   }
9550f241638SMatthias Springer };
9560f241638SMatthias Springer 
9570f241638SMatthias Springer /// Compute the indices into the memref for the LoadOp/StoreOp generated as
9580f241638SMatthias Springer /// part of TransferOp1dConversion. Return the memref dimension on which
9590f241638SMatthias Springer /// the transfer is operating. A return value of None indicates a broadcast.
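/// For illustration (a sketch with assumed names): for a 1-D transfer with
/// indices [%a, %b] and permutation_map = affine_map<(d0, d1) -> (d0)>, the
/// memref indices at loop induction variable %iv become [%a + %iv, %b] and
/// the returned dimension is 0. For a broadcast map such as
/// affine_map<(d0, d1) -> (0)>, the indices are left unchanged and None is
/// returned.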
9600f241638SMatthias Springer template <typename OpTy>
9610f241638SMatthias Springer static Optional<int64_t>
9620f241638SMatthias Springer get1dMemrefIndices(OpTy xferOp, Value iv,
9630f241638SMatthias Springer                    SmallVector<Value, 8> &memrefIndices) {
9640f241638SMatthias Springer   auto indices = xferOp.indices();
9650f241638SMatthias Springer   auto map = xferOp.permutation_map();
9660f241638SMatthias Springer 
9670f241638SMatthias Springer   memrefIndices.append(indices.begin(), indices.end());
9680f241638SMatthias Springer   assert(map.getNumResults() == 1 &&
9690f241638SMatthias Springer          "Expected 1 permutation map result for 1D transfer");
9700f241638SMatthias Springer   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
9710f241638SMatthias Springer     auto dim = expr.getPosition();
9720f241638SMatthias Springer     using edsc::op::operator+;
9730f241638SMatthias Springer     memrefIndices[dim] = memrefIndices[dim] + iv;
9740f241638SMatthias Springer     return dim;
9750f241638SMatthias Springer   }
9760f241638SMatthias Springer 
9770f241638SMatthias Springer   assert(xferOp.isBroadcastDim(0) &&
9780f241638SMatthias Springer          "Expected AffineDimExpr or AffineConstantExpr");
9790f241638SMatthias Springer   return None;
9800f241638SMatthias Springer }
9810f241638SMatthias Springer 
9820f241638SMatthias Springer /// Codegen strategy for TransferOp1dConversion, depending on the
9830f241638SMatthias Springer /// operation.
9840f241638SMatthias Springer template <typename OpTy>
9850f241638SMatthias Springer struct Strategy1d;
9860f241638SMatthias Springer 
9870f241638SMatthias Springer /// Codegen strategy for TransferReadOp.
9880f241638SMatthias Springer template <>
9890f241638SMatthias Springer struct Strategy1d<TransferReadOp> {
9900f241638SMatthias Springer   static void generateForLoopBody(OpBuilder &builder, Location loc,
9910f241638SMatthias Springer                                   TransferReadOp xferOp, Value iv,
9920f241638SMatthias Springer                                   ValueRange loopState) {
9930f241638SMatthias Springer     SmallVector<Value, 8> indices;
9940f241638SMatthias Springer     auto dim = get1dMemrefIndices(xferOp, iv, indices);
9950f241638SMatthias Springer     auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
9960f241638SMatthias Springer     auto vec = loopState[0];
9970f241638SMatthias Springer 
9980f241638SMatthias Springer     // In case of out-of-bounds access, leave `vec` as-is (it was initialized
9990f241638SMatthias Springer     // with the padding value).
10000f241638SMatthias Springer     auto nextVec = generateInBoundsCheck(
10010f241638SMatthias Springer         xferOp, iv, builder, dim, TypeRange(xferOp.getVectorType()),
10020f241638SMatthias Springer         /*inBoundsCase=*/
10030f241638SMatthias Springer         [&](OpBuilder & /*b*/, Location loc) {
10040f241638SMatthias Springer           auto val = memref_load(xferOp.source(), indices);
10050f241638SMatthias Springer           return vector_insert_element(val, vec, ivI32.value).value;
10060f241638SMatthias Springer         },
10070f241638SMatthias Springer         /*outOfBoundsCase=*/
10080f241638SMatthias Springer         [&](OpBuilder & /*b*/, Location loc) { return vec; });
10090f241638SMatthias Springer     builder.create<scf::YieldOp>(loc, nextVec);
10100f241638SMatthias Springer   }
10110f241638SMatthias Springer 
10120f241638SMatthias Springer   static Value initialLoopState(TransferReadOp xferOp) {
10130f241638SMatthias Springer     // Initialize the vector with the padding value.
10140f241638SMatthias Springer     return std_splat(xferOp.getVectorType(), xferOp.padding()).value;
10150f241638SMatthias Springer   }
10160f241638SMatthias Springer };
10170f241638SMatthias Springer 
10180f241638SMatthias Springer /// Codegen strategy for TransferWriteOp.
10190f241638SMatthias Springer template <>
10200f241638SMatthias Springer struct Strategy1d<TransferWriteOp> {
10210f241638SMatthias Springer   static void generateForLoopBody(OpBuilder &builder, Location loc,
10220f241638SMatthias Springer                                   TransferWriteOp xferOp, Value iv,
10230f241638SMatthias Springer                                   ValueRange /*loopState*/) {
10240f241638SMatthias Springer     SmallVector<Value, 8> indices;
10250f241638SMatthias Springer     auto dim = get1dMemrefIndices(xferOp, iv, indices);
10260f241638SMatthias Springer     auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv);
10270f241638SMatthias Springer 
10280f241638SMatthias Springer     // Nothing to do in case of out-of-bounds access.
10290f241638SMatthias Springer     generateInBoundsCheck(
10300f241638SMatthias Springer         xferOp, iv, builder, dim,
10310f241638SMatthias Springer         /*inBoundsCase=*/[&](OpBuilder & /*b*/, Location loc) {
10320f241638SMatthias Springer           auto val = vector_extract_element(xferOp.vector(), ivI32.value);
10330f241638SMatthias Springer           memref_store(val, xferOp.source(), indices);
10340f241638SMatthias Springer         });
10350f241638SMatthias Springer     builder.create<scf::YieldOp>(loc);
10360f241638SMatthias Springer   }
10370f241638SMatthias Springer 
10380f241638SMatthias Springer   static Value initialLoopState(TransferWriteOp xferOp) { return Value(); }
10390f241638SMatthias Springer };
10400f241638SMatthias Springer 
10410f241638SMatthias Springer /// Return true if the last dimension of the MemRefType has unit stride.
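/// E.g. (illustrative): memref<4x8xf32> has strides [8, 1], so its last
/// dimension has unit stride; a memref whose layout map yields strides
/// [16, 2] does not.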
10420f241638SMatthias Springer static bool isLastMemrefDimUnitStride(MemRefType type) {
10430f241638SMatthias Springer   int64_t offset;
10440f241638SMatthias Springer   SmallVector<int64_t, 4> strides;
10450f241638SMatthias Springer   auto successStrides = getStridesAndOffset(type, strides, offset);
10460f241638SMatthias Springer   return succeeded(successStrides) && strides.back() == 1;
10470f241638SMatthias Springer }
10480f241638SMatthias Springer 
10490f241638SMatthias Springer /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
10500f241638SMatthias Springer /// necessary in cases where a 1D vector transfer op cannot be lowered into
10510f241638SMatthias Springer /// vector load/stores due to non-unit strides or broadcasts:
10520f241638SMatthias Springer ///
10530f241638SMatthias Springer /// * Transfer dimension is not the last memref dimension
10540f241638SMatthias Springer /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
10550f241638SMatthias Springer /// * Memref has a layout map with non-unit stride on the last dimension
10560f241638SMatthias Springer ///
10570f241638SMatthias Springer /// This pattern generates IR as follows:
10580f241638SMatthias Springer ///
10590f241638SMatthias Springer /// 1. Generate a for loop iterating over each vector element.
10600f241638SMatthias Springer /// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
10610f241638SMatthias Springer ///    depending on OpTy.
10620f241638SMatthias Springer ///
10630f241638SMatthias Springer /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
10640f241638SMatthias Springer ///       can be generated instead of TransferOp1dConversion. Add such a pattern
10650f241638SMatthias Springer ///       to ConvertVectorToLLVM.
10660f241638SMatthias Springer ///
10670f241638SMatthias Springer /// E.g.:
10680f241638SMatthias Springer /// ```
10690f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b]
10700f241638SMatthias Springer ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
10710f241638SMatthias Springer ///    : vector<9xf32>, memref<?x?xf32>
10720f241638SMatthias Springer /// ```
10730f241638SMatthias Springer /// is rewritten to approximately the following pseudo-IR:
10740f241638SMatthias Springer /// ```
10750f241638SMatthias Springer /// for i = 0 to 9 {
10760f241638SMatthias Springer ///   %t = vector.extractelement %vec[i] : vector<9xf32>
10770f241638SMatthias Springer ///   memref.store %t, %A[%a + i, %b] : memref<?x?xf32>
10780f241638SMatthias Springer /// }
10790f241638SMatthias Springer /// ```
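/// Similarly (an illustrative sketch), a 1-D read with a broadcast
/// permutation map such as
/// ```
/// %vec = vector.transfer_read %A[%a, %b], %padding
///    {permutation_map = affine_map<(d0, d1) -> (0)>}
///    : memref<?x?xf32>, vector<9xf32>
/// ```
/// is rewritten to approximately the following pseudo-IR:
/// ```
/// %v = splat %padding : vector<9xf32>
/// for i = 0 to 9 {
///   %t = memref.load %A[%a, %b] : memref<?x?xf32>
///   %v = vector.insertelement %t, %v[i] : vector<9xf32>
/// }
/// ```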
10800f241638SMatthias Springer template <typename OpTy>
1081*2ca887deSMatthias Springer struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
1082*2ca887deSMatthias Springer   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
10830f241638SMatthias Springer 
10840f241638SMatthias Springer   LogicalResult matchAndRewrite(OpTy xferOp,
10850f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
10860f241638SMatthias Springer     ScopedContext scope(rewriter, xferOp.getLoc());
10870f241638SMatthias Springer     auto map = xferOp.permutation_map();
10880f241638SMatthias Springer     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
10890f241638SMatthias Springer 
10900f241638SMatthias Springer     if (!memRefType)
10910f241638SMatthias Springer       return failure();
10920f241638SMatthias Springer     if (xferOp.getVectorType().getRank() != 1)
10930f241638SMatthias Springer       return failure();
10940f241638SMatthias Springer     if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
10950f241638SMatthias Springer       return failure(); // Handled by ConvertVectorToLLVM
10960f241638SMatthias Springer 
10970f241638SMatthias Springer     // Loop bounds, step, state...
10980f241638SMatthias Springer     auto vecType = xferOp.getVectorType();
10990f241638SMatthias Springer     auto lb = std_constant_index(0);
11000f241638SMatthias Springer     auto ub = std_constant_index(vecType.getDimSize(0));
11010f241638SMatthias Springer     auto step = std_constant_index(1);
11020f241638SMatthias Springer     auto loopState = Strategy1d<OpTy>::initialLoopState(xferOp);
11030f241638SMatthias Springer 
11040f241638SMatthias Springer     // Generate for loop.
11050f241638SMatthias Springer     rewriter.replaceOpWithNewOp<scf::ForOp>(
11060f241638SMatthias Springer         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
11070f241638SMatthias Springer         [&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) {
11080f241638SMatthias Springer           ScopedContext nestedScope(builder, loc);
11090f241638SMatthias Springer           Strategy1d<OpTy>::generateForLoopBody(builder, loc, xferOp, iv,
11100f241638SMatthias Springer                                                 loopState);
11110f241638SMatthias Springer         });
11120f241638SMatthias Springer 
11130f241638SMatthias Springer     return success();
11140f241638SMatthias Springer   }
11150f241638SMatthias Springer };
11164ead2cf7SAlex Zinenko 
1117df63eedeSBenjamin Kramer } // namespace
1118df63eedeSBenjamin Kramer 
111951d30c34SBenjamin Kramer namespace mlir {
112051d30c34SBenjamin Kramer 
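// Illustrative usage from client code (a sketch; `funcOp` is an assumed
// function-like op):
//
//   VectorTransferToSCFOptions options;
//   options.setUnroll(true);
//   options.setTargetRank(1);
//   RewritePatternSet patterns(funcOp.getContext());
//   populateVectorToSCFConversionPatterns(patterns, options);
//   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));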
11213393cc4cSNicolas Vasilache void populateVectorToSCFConversionPatterns(
1122dc4e913bSChris Lattner     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
11230f241638SMatthias Springer   if (options.unroll) {
11240f241638SMatthias Springer     patterns.add<UnrollTransferReadConversion, UnrollTransferWriteConversion>(
1125*2ca887deSMatthias Springer         patterns.getContext(), options);
11260f241638SMatthias Springer   } else {
11270f241638SMatthias Springer     patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion,
11280f241638SMatthias Springer                  TransferOpConversion<TransferReadOp>,
1129*2ca887deSMatthias Springer                  TransferOpConversion<TransferWriteOp>>(patterns.getContext(),
1130*2ca887deSMatthias Springer                                                         options);
11310f241638SMatthias Springer   }
11320f241638SMatthias Springer 
1133*2ca887deSMatthias Springer   if (options.targetRank == 1) {
11340f241638SMatthias Springer     patterns.add<TransferOp1dConversion<TransferReadOp>,
1135*2ca887deSMatthias Springer                  TransferOp1dConversion<TransferWriteOp>>(patterns.getContext(),
1136*2ca887deSMatthias Springer                                                           options);
11370f241638SMatthias Springer   }
11384ead2cf7SAlex Zinenko }
11393393cc4cSNicolas Vasilache 
11403393cc4cSNicolas Vasilache } // namespace mlir
11413393cc4cSNicolas Vasilache 
11425f9e0466SNicolas Vasilache namespace {
11435f9e0466SNicolas Vasilache 
11445f9e0466SNicolas Vasilache struct ConvertVectorToSCFPass
11455f9e0466SNicolas Vasilache     : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
11465f9e0466SNicolas Vasilache   ConvertVectorToSCFPass() = default;
11475f9e0466SNicolas Vasilache   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
11485f9e0466SNicolas Vasilache     this->fullUnroll = options.unroll;
1149*2ca887deSMatthias Springer     this->targetRank = options.targetRank;
11505f9e0466SNicolas Vasilache   }
11515f9e0466SNicolas Vasilache 
11525f9e0466SNicolas Vasilache   void runOnFunction() override {
1153*2ca887deSMatthias Springer     VectorTransferToSCFOptions options;
1154*2ca887deSMatthias Springer     options.setUnroll(fullUnroll);
1155*2ca887deSMatthias Springer     options.setTargetRank(targetRank);
1156*2ca887deSMatthias Springer 
1157dc4e913bSChris Lattner     RewritePatternSet patterns(getFunction().getContext());
1158*2ca887deSMatthias Springer     populateVectorToSCFConversionPatterns(patterns, options);
1159e21adfa3SRiver Riddle     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
11605f9e0466SNicolas Vasilache   }
11615f9e0466SNicolas Vasilache };
11625f9e0466SNicolas Vasilache 
11635f9e0466SNicolas Vasilache } // namespace
11645f9e0466SNicolas Vasilache 
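// Illustrative usage (a sketch; `pm` and `options` are assumed): the pass runs
// on functions, so it is typically added as a nested pass, e.g.
//   pm.addNestedPass<FuncOp>(createConvertVectorToSCFPass(options));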
11655f9e0466SNicolas Vasilache std::unique_ptr<Pass>
11665f9e0466SNicolas Vasilache mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
11675f9e0466SNicolas Vasilache   return std::make_unique<ConvertVectorToSCFPass>(options);
11685f9e0466SNicolas Vasilache }
1169