10f241638SMatthias Springer //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
24ead2cf7SAlex Zinenko //
34ead2cf7SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44ead2cf7SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
54ead2cf7SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64ead2cf7SAlex Zinenko //
74ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===//
84ead2cf7SAlex Zinenko //
90f241638SMatthias Springer // This file implements lowering of vector transfer operations to SCF.
104ead2cf7SAlex Zinenko //
114ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===//
124ead2cf7SAlex Zinenko 
134ead2cf7SAlex Zinenko #include <type_traits>
144ead2cf7SAlex Zinenko 
154ead2cf7SAlex Zinenko #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
165f9e0466SNicolas Vasilache 
175f9e0466SNicolas Vasilache #include "../PassDetail.h"
186825bfe2SNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h"
196825bfe2SNicolas Vasilache #include "mlir/Dialect/Affine/Utils.h"
206825bfe2SNicolas Vasilache #include "mlir/Dialect/SCF/SCF.h"
214ead2cf7SAlex Zinenko #include "mlir/Dialect/Vector/VectorOps.h"
227c3c5b11SNicolas Vasilache #include "mlir/Dialect/Vector/VectorUtils.h"
234ead2cf7SAlex Zinenko #include "mlir/IR/Builders.h"
246825bfe2SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
255f9e0466SNicolas Vasilache #include "mlir/Pass/Pass.h"
26b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
275f9e0466SNicolas Vasilache #include "mlir/Transforms/Passes.h"
284ead2cf7SAlex Zinenko 
294ead2cf7SAlex Zinenko using namespace mlir;
304ead2cf7SAlex Zinenko using vector::TransferReadOp;
314ead2cf7SAlex Zinenko using vector::TransferWriteOp;
324ead2cf7SAlex Zinenko 
33350dadaaSBenjamin Kramer namespace {
340f241638SMatthias Springer 
/// Name of the unit attribute used to label vector transfer ops during
/// progressive lowering. Ops that already carry this label are rejected by
/// checkPrepareXferOp.
static constexpr char kPassLabel[] = "__vector_to_scf_lowering__";
370f241638SMatthias Springer 
382ca887deSMatthias Springer /// Patterns that inherit from this struct have access to
392ca887deSMatthias Springer /// VectorTransferToSCFOptions.
402ca887deSMatthias Springer template <typename OpTy>
412ca887deSMatthias Springer struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
422ca887deSMatthias Springer   explicit VectorToSCFPattern(MLIRContext *context,
432ca887deSMatthias Springer                               VectorTransferToSCFOptions opt)
442ca887deSMatthias Springer       : OpRewritePattern<OpTy>(context), options(opt) {}
452ca887deSMatthias Springer 
462ca887deSMatthias Springer   VectorTransferToSCFOptions options;
472ca887deSMatthias Springer };
480f241638SMatthias Springer 
490f241638SMatthias Springer /// Given a vector transfer op, calculate which dimension of the `source`
500f241638SMatthias Springer /// memref should be unpacked in the next application of TransferOpConversion.
510f241638SMatthias Springer /// A return value of None indicates a broadcast.
520f241638SMatthias Springer template <typename OpTy>
530f241638SMatthias Springer static Optional<int64_t> unpackedDim(OpTy xferOp) {
540f241638SMatthias Springer   auto map = xferOp.permutation_map();
550f241638SMatthias Springer   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
560f241638SMatthias Springer     return expr.getPosition();
577c3c5b11SNicolas Vasilache   }
580f241638SMatthias Springer   assert(xferOp.isBroadcastDim(0) &&
590f241638SMatthias Springer          "Expected AffineDimExpr or AffineConstantExpr");
600f241638SMatthias Springer   return None;
610f241638SMatthias Springer }
620f241638SMatthias Springer 
630f241638SMatthias Springer /// Compute the permutation map for the new (N-1)-D vector transfer op. This
640f241638SMatthias Springer /// map is identical to the current permutation map, but the first result is
650f241638SMatthias Springer /// omitted.
660f241638SMatthias Springer template <typename OpTy>
676825bfe2SNicolas Vasilache static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
680f241638SMatthias Springer   auto map = xferOp.permutation_map();
690f241638SMatthias Springer   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
706825bfe2SNicolas Vasilache                         b.getContext());
710f241638SMatthias Springer }
720f241638SMatthias Springer 
730f241638SMatthias Springer /// Calculate the indices for the new vector transfer op.
740f241638SMatthias Springer ///
750f241638SMatthias Springer /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
760f241638SMatthias Springer ///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32>
770f241638SMatthias Springer ///                                 ^^^^^^
780f241638SMatthias Springer ///              `iv` is the iteration variable of the (new) surrounding loop.
790f241638SMatthias Springer template <typename OpTy>
806825bfe2SNicolas Vasilache static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
810f241638SMatthias Springer                            SmallVector<Value, 8> &indices) {
820f241638SMatthias Springer   typename OpTy::Adaptor adaptor(xferOp);
830f241638SMatthias Springer   // Corresponding memref dim of the vector dim that is unpacked.
840f241638SMatthias Springer   auto dim = unpackedDim(xferOp);
850f241638SMatthias Springer   auto prevIndices = adaptor.indices();
860f241638SMatthias Springer   indices.append(prevIndices.begin(), prevIndices.end());
870f241638SMatthias Springer 
886825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
890f241638SMatthias Springer   bool isBroadcast = !dim.hasValue();
900f241638SMatthias Springer   if (!isBroadcast) {
916825bfe2SNicolas Vasilache     AffineExpr d0, d1;
926825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
936825bfe2SNicolas Vasilache     Value offset = adaptor.indices()[dim.getValue()];
946825bfe2SNicolas Vasilache     indices[dim.getValue()] =
956825bfe2SNicolas Vasilache         makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
960f241638SMatthias Springer   }
970f241638SMatthias Springer }
980f241638SMatthias Springer 
996825bfe2SNicolas Vasilache static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
1000f241638SMatthias Springer                             Value value) {
1010f241638SMatthias Springer   if (hasRetVal) {
1026825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc, value);
1030f241638SMatthias Springer   } else {
1046825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc);
1050f241638SMatthias Springer   }
1060f241638SMatthias Springer }
1070f241638SMatthias Springer 
1080f241638SMatthias Springer /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
1090f241638SMatthias Springer /// is set to true. No such check is generated under following circumstances:
1100f241638SMatthias Springer /// * xferOp does not have a mask.
1110f241638SMatthias Springer /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
1120f241638SMatthias Springer ///   computed and attached to the new transfer op in the pattern.)
1130f241638SMatthias Springer /// * The to-be-unpacked dim of xferOp is a broadcast.
1140f241638SMatthias Springer template <typename OpTy>
1156825bfe2SNicolas Vasilache static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
1160f241638SMatthias Springer   if (!xferOp.mask())
1170f241638SMatthias Springer     return Value();
1180f241638SMatthias Springer   if (xferOp.getMaskType().getRank() != 1)
1190f241638SMatthias Springer     return Value();
1200f241638SMatthias Springer   if (xferOp.isBroadcastDim(0))
1210f241638SMatthias Springer     return Value();
1220f241638SMatthias Springer 
1236825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
1246825bfe2SNicolas Vasilache   Value ivI32 =
1256825bfe2SNicolas Vasilache       b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
1266825bfe2SNicolas Vasilache   return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
1270f241638SMatthias Springer }
1280f241638SMatthias Springer 
1290f241638SMatthias Springer /// Helper function TransferOpConversion and TransferOp1dConversion.
1300f241638SMatthias Springer /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
1310f241638SMatthias Springer /// specified dimension `dim` with the loop iteration variable `iv`.
1320f241638SMatthias Springer /// E.g., when unpacking dimension 0 from:
1330f241638SMatthias Springer /// ```
1340f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b] %cst
1350f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?xf32>
1360f241638SMatthias Springer /// ```
1370f241638SMatthias Springer /// An if check similar to this will be generated inside the loop:
1380f241638SMatthias Springer /// ```
1390f241638SMatthias Springer /// %d = memref.dim %A, %c0 : memref<?x?xf32>
1400f241638SMatthias Springer /// if (%a + iv < %d) {
1410f241638SMatthias Springer ///   (in-bounds case)
1420f241638SMatthias Springer /// } else {
1430f241638SMatthias Springer ///   (out-of-bounds case)
1440f241638SMatthias Springer /// }
1450f241638SMatthias Springer /// ```
1460f241638SMatthias Springer ///
1470f241638SMatthias Springer /// If the transfer is 1D and has a mask, this function generates a more complex
1480f241638SMatthias Springer /// check also accounts for potentially masked out elements.
1490f241638SMatthias Springer ///
1500f241638SMatthias Springer /// This function variant returns the value returned by `inBoundsCase` or
1510f241638SMatthias Springer /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
1520f241638SMatthias Springer /// `resultTypes`.
1530f241638SMatthias Springer template <typename OpTy>
1540f241638SMatthias Springer static Value generateInBoundsCheck(
1556825bfe2SNicolas Vasilache     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
1560f241638SMatthias Springer     TypeRange resultTypes,
1570f241638SMatthias Springer     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
1580f241638SMatthias Springer     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
1590f241638SMatthias Springer   bool hasRetVal = !resultTypes.empty();
1600f241638SMatthias Springer   Value cond; // Condition to be built...
1610f241638SMatthias Springer 
1620f241638SMatthias Springer   // Condition check 1: Access in-bounds?
1630f241638SMatthias Springer   bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
1646825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
1656825bfe2SNicolas Vasilache   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
1660f241638SMatthias Springer   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
1676825bfe2SNicolas Vasilache     Value memrefDim = lb.create<memref::DimOp>(xferOp.source(), *dim);
1686825bfe2SNicolas Vasilache     AffineExpr d0, d1;
1696825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
1706825bfe2SNicolas Vasilache     Value base = xferOp.indices()[dim.getValue()];
1716825bfe2SNicolas Vasilache     Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
1726825bfe2SNicolas Vasilache     cond = lb.create<CmpIOp>(CmpIPredicate::sgt, memrefDim, memrefIdx);
1730f241638SMatthias Springer   }
1740f241638SMatthias Springer 
1750f241638SMatthias Springer   // Condition check 2: Masked in?
1766825bfe2SNicolas Vasilache   if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
1776825bfe2SNicolas Vasilache     if (cond)
1786825bfe2SNicolas Vasilache       cond = lb.create<AndOp>(cond, maskCond);
1796825bfe2SNicolas Vasilache     else
1800f241638SMatthias Springer       cond = maskCond;
1810f241638SMatthias Springer   }
1820f241638SMatthias Springer 
1830f241638SMatthias Springer   // If the condition is non-empty, generate an SCF::IfOp.
1840f241638SMatthias Springer   if (cond) {
1856825bfe2SNicolas Vasilache     auto check = lb.create<scf::IfOp>(
1866825bfe2SNicolas Vasilache         resultTypes, cond,
1870f241638SMatthias Springer         /*thenBuilder=*/
1886825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
1896825bfe2SNicolas Vasilache           maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
190cadb7ccfSAlex Zinenko         },
1910f241638SMatthias Springer         /*elseBuilder=*/
1926825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
1930f241638SMatthias Springer           if (outOfBoundsCase) {
1946825bfe2SNicolas Vasilache             maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
1957c3c5b11SNicolas Vasilache           } else {
1966825bfe2SNicolas Vasilache             b.create<scf::YieldOp>(loc);
1977c3c5b11SNicolas Vasilache           }
1987c3c5b11SNicolas Vasilache         });
1997c3c5b11SNicolas Vasilache 
2000f241638SMatthias Springer     return hasRetVal ? check.getResult(0) : Value();
2014ead2cf7SAlex Zinenko   }
2024ead2cf7SAlex Zinenko 
2030f241638SMatthias Springer   // Condition is empty, no need for an SCF::IfOp.
2046825bfe2SNicolas Vasilache   return inBoundsCase(b, loc);
2050f241638SMatthias Springer }
2060f241638SMatthias Springer 
2070f241638SMatthias Springer /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
2080f241638SMatthias Springer /// a return value. Consequently, this function does not have a return value.
2090f241638SMatthias Springer template <typename OpTy>
2100f241638SMatthias Springer static void generateInBoundsCheck(
2116825bfe2SNicolas Vasilache     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
2120f241638SMatthias Springer     function_ref<void(OpBuilder &, Location)> inBoundsCase,
2130f241638SMatthias Springer     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
2140f241638SMatthias Springer   generateInBoundsCheck(
2156825bfe2SNicolas Vasilache       b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
2160f241638SMatthias Springer       /*inBoundsCase=*/
2176825bfe2SNicolas Vasilache       [&](OpBuilder &b, Location loc) {
2186825bfe2SNicolas Vasilache         inBoundsCase(b, loc);
2190f241638SMatthias Springer         return Value();
2200f241638SMatthias Springer       },
2210f241638SMatthias Springer       /*outOfBoundsCase=*/
2226825bfe2SNicolas Vasilache       [&](OpBuilder &b, Location loc) {
2230f241638SMatthias Springer         if (outOfBoundsCase)
2246825bfe2SNicolas Vasilache           outOfBoundsCase(b, loc);
2250f241638SMatthias Springer         return Value();
2260f241638SMatthias Springer       });
2270f241638SMatthias Springer }
2280f241638SMatthias Springer 
2290f241638SMatthias Springer /// Given an ArrayAttr, return a copy where the first element is dropped.
2306825bfe2SNicolas Vasilache static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
2310f241638SMatthias Springer   if (!attr)
2320f241638SMatthias Springer     return attr;
2336825bfe2SNicolas Vasilache   return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
2340f241638SMatthias Springer }
2350f241638SMatthias Springer 
2360f241638SMatthias Springer /// Add the pass label to a vector transfer op if its rank is not the target
2370f241638SMatthias Springer /// rank.
2380f241638SMatthias Springer template <typename OpTy>
2396825bfe2SNicolas Vasilache static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
2402ca887deSMatthias Springer                                 unsigned targetRank) {
2412ca887deSMatthias Springer   if (newXferOp.getVectorType().getRank() > targetRank)
2426825bfe2SNicolas Vasilache     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
2430f241638SMatthias Springer }
2440f241638SMatthias Springer 
245a088bed4SMatthias Springer namespace lowering_n_d {
246a088bed4SMatthias Springer 
247a088bed4SMatthias Springer /// Helper data structure for data and mask buffers.
248a088bed4SMatthias Springer struct BufferAllocs {
249a088bed4SMatthias Springer   Value dataBuffer;
250a088bed4SMatthias Springer   Value maskBuffer;
251a088bed4SMatthias Springer };
252a088bed4SMatthias Springer 
253a088bed4SMatthias Springer /// Allocate temporary buffers for data (vector) and mask (if present).
254a088bed4SMatthias Springer /// TODO: Parallelism and threadlocal considerations.
255a088bed4SMatthias Springer template <typename OpTy>
2566825bfe2SNicolas Vasilache static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
2576825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
258a088bed4SMatthias Springer   OpBuilder::InsertionGuard guard(b);
259a088bed4SMatthias Springer   Operation *scope =
260a088bed4SMatthias Springer       xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
261a088bed4SMatthias Springer   assert(scope && "Expected op to be inside automatic allocation scope");
262a088bed4SMatthias Springer   b.setInsertionPointToStart(&scope->getRegion(0).front());
263a088bed4SMatthias Springer 
264a088bed4SMatthias Springer   BufferAllocs result;
265a088bed4SMatthias Springer   auto bufferType = MemRefType::get({}, xferOp.getVectorType());
2666825bfe2SNicolas Vasilache   result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
267a088bed4SMatthias Springer 
268a088bed4SMatthias Springer   if (xferOp.mask()) {
269a088bed4SMatthias Springer     auto maskType = MemRefType::get({}, xferOp.mask().getType());
2706825bfe2SNicolas Vasilache     auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
271fb7ec1f1SMatthias Springer     b.setInsertionPoint(xferOp);
2726825bfe2SNicolas Vasilache     b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
2736825bfe2SNicolas Vasilache     result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
274a088bed4SMatthias Springer   }
275a088bed4SMatthias Springer 
276a088bed4SMatthias Springer   return result;
277a088bed4SMatthias Springer }
278a088bed4SMatthias Springer 
279a088bed4SMatthias Springer /// Given a MemRefType with VectorType element type, unpack one dimension from
280a088bed4SMatthias Springer /// the VectorType into the MemRefType.
281a088bed4SMatthias Springer ///
282a088bed4SMatthias Springer /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
283a088bed4SMatthias Springer static MemRefType unpackOneDim(MemRefType type) {
284a088bed4SMatthias Springer   auto vectorType = type.getElementType().dyn_cast<VectorType>();
285a088bed4SMatthias Springer   auto memrefShape = type.getShape();
286a088bed4SMatthias Springer   SmallVector<int64_t, 8> newMemrefShape;
287a088bed4SMatthias Springer   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
288a088bed4SMatthias Springer   newMemrefShape.push_back(vectorType.getDimSize(0));
289a088bed4SMatthias Springer   return MemRefType::get(newMemrefShape,
290a088bed4SMatthias Springer                          VectorType::get(vectorType.getShape().drop_front(),
291a088bed4SMatthias Springer                                          vectorType.getElementType()));
292a088bed4SMatthias Springer }
293a088bed4SMatthias Springer 
2940f241638SMatthias Springer /// Given a transfer op, find the memref from which the mask is loaded. This
2950f241638SMatthias Springer /// is similar to Strategy<TransferWriteOp>::getBuffer.
2960f241638SMatthias Springer template <typename OpTy>
2970f241638SMatthias Springer static Value getMaskBuffer(OpTy xferOp) {
2980f241638SMatthias Springer   assert(xferOp.mask() && "Expected that transfer op has mask");
2990f241638SMatthias Springer   auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
3000f241638SMatthias Springer   assert(loadOp && "Expected transfer op mask produced by LoadOp");
3010f241638SMatthias Springer   return loadOp.getMemRef();
3020f241638SMatthias Springer }
3030f241638SMatthias Springer 
/// Codegen strategy, selected by the transfer op type through explicit
/// specialization (e.g. for TransferReadOp and TransferWriteOp).
template <class OpTy>
struct Strategy;
3070f241638SMatthias Springer 
3080f241638SMatthias Springer /// Code strategy for vector TransferReadOp.
3094ead2cf7SAlex Zinenko template <>
3100f241638SMatthias Springer struct Strategy<TransferReadOp> {
3110f241638SMatthias Springer   /// Find the StoreOp that is used for writing the current TransferReadOp's
3120f241638SMatthias Springer   /// result to the temporary buffer allocation.
3130f241638SMatthias Springer   static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
3140f241638SMatthias Springer     assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
3150f241638SMatthias Springer     auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
3160f241638SMatthias Springer     assert(storeOp && "Expected TransferReadOp result used by StoreOp");
3170f241638SMatthias Springer     return storeOp;
3187c3c5b11SNicolas Vasilache   }
3194ead2cf7SAlex Zinenko 
3200f241638SMatthias Springer   /// Find the temporary buffer allocation. All labeled TransferReadOps are
3210f241638SMatthias Springer   /// used like this, where %buf is either the buffer allocation or a type cast
3220f241638SMatthias Springer   /// of the buffer allocation:
3230f241638SMatthias Springer   /// ```
3240f241638SMatthias Springer   /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
3250f241638SMatthias Springer   /// memref.store %vec, %buf[...] ...
3260f241638SMatthias Springer   /// ```
3270f241638SMatthias Springer   static Value getBuffer(TransferReadOp xferOp) {
3280f241638SMatthias Springer     return getStoreOp(xferOp).getMemRef();
3291870e787SNicolas Vasilache   }
3300f241638SMatthias Springer 
3310f241638SMatthias Springer   /// Retrieve the indices of the current StoreOp that stores into the buffer.
3320f241638SMatthias Springer   static void getBufferIndices(TransferReadOp xferOp,
3330f241638SMatthias Springer                                SmallVector<Value, 8> &indices) {
3340f241638SMatthias Springer     auto storeOp = getStoreOp(xferOp);
3350f241638SMatthias Springer     auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
3360f241638SMatthias Springer     indices.append(prevIndices.begin(), prevIndices.end());
3370f241638SMatthias Springer   }
3380f241638SMatthias Springer 
3390f241638SMatthias Springer   /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
3400f241638SMatthias Springer   /// accesses on the to-be-unpacked dimension.
3410f241638SMatthias Springer   ///
3420f241638SMatthias Springer   /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
3430f241638SMatthias Springer   ///    variable `iv`.
3440f241638SMatthias Springer   /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
3450f241638SMatthias Springer   ///
3460f241638SMatthias Springer   /// E.g.:
3470f241638SMatthias Springer   /// ```
3480f241638SMatthias Springer   /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
3490f241638SMatthias Springer   ///     : memref<?x?x?xf32>, vector<4x3xf32>
3500f241638SMatthias Springer   /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
3510f241638SMatthias Springer   /// ```
3520f241638SMatthias Springer   /// Is rewritten to:
3530f241638SMatthias Springer   /// ```
3540f241638SMatthias Springer   /// %casted = vector.type_cast %buf
3550f241638SMatthias Springer   ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
3560f241638SMatthias Springer   /// for %j = 0 to 4 {
3570f241638SMatthias Springer   ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
3580f241638SMatthias Springer   ///       : memref<?x?x?xf32>, vector<3xf32>
3590f241638SMatthias Springer   ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
3600f241638SMatthias Springer   /// }
3610f241638SMatthias Springer   /// ```
3620f241638SMatthias Springer   ///
3630f241638SMatthias Springer   /// Note: The loop and type cast are generated in TransferOpConversion.
3640f241638SMatthias Springer   ///       The original TransferReadOp and store op are deleted in `cleanup`.
3650f241638SMatthias Springer   /// Note: The `mask` operand is set in TransferOpConversion.
3666825bfe2SNicolas Vasilache   static TransferReadOp rewriteOp(OpBuilder &b,
3672ca887deSMatthias Springer                                   VectorTransferToSCFOptions options,
3682ca887deSMatthias Springer                                   TransferReadOp xferOp, Value buffer,
3692ca887deSMatthias Springer                                   Value iv) {
3700f241638SMatthias Springer     SmallVector<Value, 8> storeIndices;
3710f241638SMatthias Springer     getBufferIndices(xferOp, storeIndices);
3720f241638SMatthias Springer     storeIndices.push_back(iv);
3730f241638SMatthias Springer 
3740f241638SMatthias Springer     SmallVector<Value, 8> xferIndices;
3756825bfe2SNicolas Vasilache     getXferIndices(b, xferOp, iv, xferIndices);
3760f241638SMatthias Springer 
3776825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
3780f241638SMatthias Springer     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
3790f241638SMatthias Springer     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
3806825bfe2SNicolas Vasilache     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
3816825bfe2SNicolas Vasilache     auto newXferOp = b.create<vector::TransferReadOp>(
3826825bfe2SNicolas Vasilache         loc, vecType, xferOp.source(), xferIndices,
3836825bfe2SNicolas Vasilache         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
3846825bfe2SNicolas Vasilache         Value(), inBoundsAttr);
3850f241638SMatthias Springer 
3866825bfe2SNicolas Vasilache     maybeApplyPassLabel(b, newXferOp, options.targetRank);
3870f241638SMatthias Springer 
3886825bfe2SNicolas Vasilache     b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
3896825bfe2SNicolas Vasilache     return newXferOp;
3900f241638SMatthias Springer   }
3910f241638SMatthias Springer 
3920f241638SMatthias Springer   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
3930f241638SMatthias Springer   /// padding value to the temporary buffer.
3946825bfe2SNicolas Vasilache   static void handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
3956825bfe2SNicolas Vasilache                                    Value buffer, Value iv) {
3960f241638SMatthias Springer     SmallVector<Value, 8> storeIndices;
3970f241638SMatthias Springer     getBufferIndices(xferOp, storeIndices);
3980f241638SMatthias Springer     storeIndices.push_back(iv);
3990f241638SMatthias Springer 
4006825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
4010f241638SMatthias Springer     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
4020f241638SMatthias Springer     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
4036825bfe2SNicolas Vasilache     auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
4046825bfe2SNicolas Vasilache     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
4050f241638SMatthias Springer   }
4060f241638SMatthias Springer 
4070f241638SMatthias Springer   /// Cleanup after rewriting the op.
4080f241638SMatthias Springer   static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp) {
4090f241638SMatthias Springer     rewriter.eraseOp(getStoreOp(xferOp));
4100f241638SMatthias Springer     rewriter.eraseOp(xferOp);
4110f241638SMatthias Springer   }
4124ead2cf7SAlex Zinenko };
4137c3c5b11SNicolas Vasilache 
4140f241638SMatthias Springer /// Codegen strategy for vector TransferWriteOp.
4150f241638SMatthias Springer template <>
4160f241638SMatthias Springer struct Strategy<TransferWriteOp> {
4170f241638SMatthias Springer   /// Find the temporary buffer allocation. All labeled TransferWriteOps are
4180f241638SMatthias Springer   /// used like this, where %buf is either the buffer allocation or a type cast
4190f241638SMatthias Springer   /// of the buffer allocation:
4200f241638SMatthias Springer   /// ```
4210f241638SMatthias Springer   /// %vec = memref.load %buf[...] ...
4220f241638SMatthias Springer   /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
4230f241638SMatthias Springer   /// ```
4240f241638SMatthias Springer   static Value getBuffer(TransferWriteOp xferOp) {
4250f241638SMatthias Springer     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
4260f241638SMatthias Springer     assert(loadOp && "Expected transfer op vector produced by LoadOp");
4270f241638SMatthias Springer     return loadOp.getMemRef();
4287c3c5b11SNicolas Vasilache   }
4294ead2cf7SAlex Zinenko 
4300f241638SMatthias Springer   /// Retrieve the indices of the current LoadOp that loads from the buffer.
4310f241638SMatthias Springer   static void getBufferIndices(TransferWriteOp xferOp,
4320f241638SMatthias Springer                                SmallVector<Value, 8> &indices) {
4330f241638SMatthias Springer     auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
4340f241638SMatthias Springer     auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
4350f241638SMatthias Springer     indices.append(prevIndices.begin(), prevIndices.end());
4360f241638SMatthias Springer   }
4370f241638SMatthias Springer 
4380f241638SMatthias Springer   /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
4390f241638SMatthias Springer   /// accesses on the to-be-unpacked dimension.
4400f241638SMatthias Springer   ///
4410f241638SMatthias Springer   /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
4420f241638SMatthias Springer   ///    using the loop iteration variable `iv`.
4430f241638SMatthias Springer   /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
4440f241638SMatthias Springer   ///    to memory.
4450f241638SMatthias Springer   ///
4460f241638SMatthias Springer   /// Note: For more details, see comments on Strategy<TransferReadOp>.
4476825bfe2SNicolas Vasilache   static TransferWriteOp rewriteOp(OpBuilder &b,
4482ca887deSMatthias Springer                                    VectorTransferToSCFOptions options,
4492ca887deSMatthias Springer                                    TransferWriteOp xferOp, Value buffer,
4502ca887deSMatthias Springer                                    Value iv) {
4510f241638SMatthias Springer     SmallVector<Value, 8> loadIndices;
4520f241638SMatthias Springer     getBufferIndices(xferOp, loadIndices);
4530f241638SMatthias Springer     loadIndices.push_back(iv);
4540f241638SMatthias Springer 
4550f241638SMatthias Springer     SmallVector<Value, 8> xferIndices;
4566825bfe2SNicolas Vasilache     getXferIndices(b, xferOp, iv, xferIndices);
4570f241638SMatthias Springer 
4586825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
4596825bfe2SNicolas Vasilache     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
4606825bfe2SNicolas Vasilache     auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
4616825bfe2SNicolas Vasilache     auto newXferOp = b.create<vector::TransferWriteOp>(
4626825bfe2SNicolas Vasilache         loc, Type(), vec, xferOp.source(), xferIndices,
4636825bfe2SNicolas Vasilache         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
4640f241638SMatthias Springer         inBoundsAttr);
4650f241638SMatthias Springer 
4666825bfe2SNicolas Vasilache     maybeApplyPassLabel(b, newXferOp, options.targetRank);
4670f241638SMatthias Springer 
4686825bfe2SNicolas Vasilache     return newXferOp;
4690f241638SMatthias Springer   }
4700f241638SMatthias Springer 
4710f241638SMatthias Springer   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
4726825bfe2SNicolas Vasilache   static void handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
4730f241638SMatthias Springer                                    Value buffer, Value iv) {}
4740f241638SMatthias Springer 
4750f241638SMatthias Springer   /// Cleanup after rewriting the op.
4760f241638SMatthias Springer   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) {
4770f241638SMatthias Springer     rewriter.eraseOp(xferOp);
4780f241638SMatthias Springer   }
4790f241638SMatthias Springer };
4800f241638SMatthias Springer 
4810f241638SMatthias Springer template <typename OpTy>
482fb7ec1f1SMatthias Springer LogicalResult checkPrepareXferOp(OpTy xferOp,
483fb7ec1f1SMatthias Springer                                  VectorTransferToSCFOptions options) {
4840f241638SMatthias Springer   if (xferOp->hasAttr(kPassLabel))
4850f241638SMatthias Springer     return failure();
486fb7ec1f1SMatthias Springer   if (xferOp.getVectorType().getRank() <= options.targetRank)
4870f241638SMatthias Springer     return failure();
4888fb48979SMatthias Springer   if (xferOp.getShapedType().template isa<RankedTensorType>())
4898fb48979SMatthias Springer     return failure();
490*f718a53dSMatthias Springer   // Transfer ops that modify the element type are not supported atm.
491*f718a53dSMatthias Springer   if (xferOp.getVectorType().getElementType() !=
492*f718a53dSMatthias Springer       xferOp.getShapedType().getElementType())
493*f718a53dSMatthias Springer     return failure();
4940f241638SMatthias Springer   return success();
4950f241638SMatthias Springer }
4960f241638SMatthias Springer 
4970f241638SMatthias Springer /// Prepare a TransferReadOp for progressive lowering.
4980f241638SMatthias Springer ///
4990f241638SMatthias Springer /// 1. Allocate a temporary buffer.
5000f241638SMatthias Springer /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
5010f241638SMatthias Springer /// 3. Store the result of the TransferReadOp into the temporary buffer.
5020f241638SMatthias Springer /// 4. Load the result from the temporary buffer and replace all uses of the
5030f241638SMatthias Springer ///    original TransferReadOp with this load.
5040f241638SMatthias Springer ///
5050f241638SMatthias Springer /// E.g.:
5060f241638SMatthias Springer /// ```
5070f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
5080f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
5090f241638SMatthias Springer /// ```
5100f241638SMatthias Springer /// is rewritten to:
5110f241638SMatthias Springer /// ```
5120f241638SMatthias Springer /// %0 = memref.alloca() : memref<vector<5x4xf32>>
5130f241638SMatthias Springer /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
5140f241638SMatthias Springer ///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
5150f241638SMatthias Springer /// memref.store %1, %0[] : memref<vector<5x4xf32>>
5160f241638SMatthias Springer /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
5170f241638SMatthias Springer /// ```
5180f241638SMatthias Springer ///
5190f241638SMatthias Springer /// Note: A second temporary buffer may be allocated for the `mask` operand.
5202ca887deSMatthias Springer struct PrepareTransferReadConversion
5212ca887deSMatthias Springer     : public VectorToSCFPattern<TransferReadOp> {
5222ca887deSMatthias Springer   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
5230f241638SMatthias Springer 
5240f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferReadOp xferOp,
5250f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
526fb7ec1f1SMatthias Springer     if (checkPrepareXferOp(xferOp, options).failed())
5270f241638SMatthias Springer       return failure();
5280f241638SMatthias Springer 
5296825bfe2SNicolas Vasilache     auto buffers = allocBuffers(rewriter, xferOp);
5300f241638SMatthias Springer     auto *newXfer = rewriter.clone(*xferOp.getOperation());
5310f241638SMatthias Springer     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
5320f241638SMatthias Springer     if (xferOp.mask()) {
5330f241638SMatthias Springer       dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
5340f241638SMatthias Springer           buffers.maskBuffer);
5350f241638SMatthias Springer     }
5360f241638SMatthias Springer 
5376825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
5386825bfe2SNicolas Vasilache     rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
5396825bfe2SNicolas Vasilache                                      buffers.dataBuffer);
5400f241638SMatthias Springer     rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
5414ead2cf7SAlex Zinenko 
5424ead2cf7SAlex Zinenko     return success();
5434ead2cf7SAlex Zinenko   }
5440f241638SMatthias Springer };
5450f241638SMatthias Springer 
5460f241638SMatthias Springer /// Prepare a TransferWriteOp for progressive lowering.
5470f241638SMatthias Springer ///
5480f241638SMatthias Springer /// 1. Allocate a temporary buffer.
5490f241638SMatthias Springer /// 2. Store the vector into the buffer.
5500f241638SMatthias Springer /// 3. Load the vector from the buffer again.
5510f241638SMatthias Springer /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
5520f241638SMatthias Springer ///    marking it eligible for progressive lowering via TransferOpConversion.
5530f241638SMatthias Springer ///
5540f241638SMatthias Springer /// E.g.:
5550f241638SMatthias Springer /// ```
5560f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c]
5570f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
5580f241638SMatthias Springer /// ```
5590f241638SMatthias Springer /// is rewritten to:
5600f241638SMatthias Springer /// ```
5610f241638SMatthias Springer /// %0 = memref.alloca() : memref<vector<5x4xf32>>
5620f241638SMatthias Springer /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
5630f241638SMatthias Springer /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
5640f241638SMatthias Springer /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
5650f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
5660f241638SMatthias Springer /// ```
5670f241638SMatthias Springer ///
5680f241638SMatthias Springer /// Note: A second temporary buffer may be allocated for the `mask` operand.
5690f241638SMatthias Springer struct PrepareTransferWriteConversion
5702ca887deSMatthias Springer     : public VectorToSCFPattern<TransferWriteOp> {
5712ca887deSMatthias Springer   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
5720f241638SMatthias Springer 
5730f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
5740f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
575fb7ec1f1SMatthias Springer     if (checkPrepareXferOp(xferOp, options).failed())
5760f241638SMatthias Springer       return failure();
5770f241638SMatthias Springer 
5786825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
5796825bfe2SNicolas Vasilache     auto buffers = allocBuffers(rewriter, xferOp);
5806825bfe2SNicolas Vasilache     rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
5816825bfe2SNicolas Vasilache     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
5820f241638SMatthias Springer     rewriter.updateRootInPlace(xferOp, [&]() {
5830f241638SMatthias Springer       xferOp.vectorMutable().assign(loadedVec);
5840f241638SMatthias Springer       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
5850f241638SMatthias Springer     });
5860f241638SMatthias Springer 
5870f241638SMatthias Springer     if (xferOp.mask()) {
5880f241638SMatthias Springer       rewriter.updateRootInPlace(
5890f241638SMatthias Springer           xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
5900f241638SMatthias Springer     }
5910f241638SMatthias Springer 
5920f241638SMatthias Springer     return success();
5930f241638SMatthias Springer   }
5940f241638SMatthias Springer };
5950f241638SMatthias Springer 
5960f241638SMatthias Springer /// Progressive lowering of vector transfer ops: Unpack one dimension.
5970f241638SMatthias Springer ///
5980f241638SMatthias Springer /// 1. Unpack one dimension from the current buffer type and cast the buffer
5990f241638SMatthias Springer ///    to that new type. E.g.:
6000f241638SMatthias Springer ///    ```
6010f241638SMatthias Springer ///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
6020f241638SMatthias Springer ///    vector.transfer_write %vec ...
6030f241638SMatthias Springer ///    ```
6040f241638SMatthias Springer ///    The following cast is generated:
6050f241638SMatthias Springer ///    ```
6060f241638SMatthias Springer ///    %casted = vector.type_cast %0
6070f241638SMatthias Springer ///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
6080f241638SMatthias Springer ///    ```
6090f241638SMatthias Springer /// 2. Generate a for loop and rewrite the transfer op according to the
6100f241638SMatthias Springer ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
6110f241638SMatthias Springer ///    out-of-bounds, generate an if-check and handle both cases separately.
6120f241638SMatthias Springer /// 3. Clean up according to the corresponding Strategy<OpTy>.
6130f241638SMatthias Springer template <typename OpTy>
6142ca887deSMatthias Springer struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
6152ca887deSMatthias Springer   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
6160f241638SMatthias Springer 
6170f241638SMatthias Springer   LogicalResult matchAndRewrite(OpTy xferOp,
6180f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
6190f241638SMatthias Springer     if (!xferOp->hasAttr(kPassLabel))
6200f241638SMatthias Springer       return failure();
6210f241638SMatthias Springer 
6220f241638SMatthias Springer     // Find and cast data buffer. How the buffer can be found depends on OpTy.
6236825bfe2SNicolas Vasilache     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
6240f241638SMatthias Springer     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
6250f241638SMatthias Springer     auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
6260f241638SMatthias Springer     auto castedDataType = unpackOneDim(dataBufferType);
6276825bfe2SNicolas Vasilache     auto castedDataBuffer =
6286825bfe2SNicolas Vasilache         locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
6290f241638SMatthias Springer 
6300f241638SMatthias Springer     // If the xferOp has a mask: Find and cast mask buffer.
6310f241638SMatthias Springer     Value castedMaskBuffer;
6320f241638SMatthias Springer     if (xferOp.mask()) {
6330f241638SMatthias Springer       auto maskBuffer = getMaskBuffer(xferOp);
6340f241638SMatthias Springer       auto maskBufferType =
6350f241638SMatthias Springer           maskBuffer.getType().template dyn_cast<MemRefType>();
6360f241638SMatthias Springer       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
6370f241638SMatthias Springer         // Do not unpack a dimension of the mask, if:
6380f241638SMatthias Springer         // * To-be-unpacked transfer op dimension is a broadcast.
6390f241638SMatthias Springer         // * Mask is 1D, i.e., the mask cannot be further unpacked.
6400f241638SMatthias Springer         //   (That means that all remaining dimensions of the transfer op must
6410f241638SMatthias Springer         //   be broadcasted.)
6420f241638SMatthias Springer         castedMaskBuffer = maskBuffer;
6430f241638SMatthias Springer       } else {
6440f241638SMatthias Springer         auto castedMaskType = unpackOneDim(maskBufferType);
6456825bfe2SNicolas Vasilache         castedMaskBuffer =
6466825bfe2SNicolas Vasilache             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
6470f241638SMatthias Springer       }
6480f241638SMatthias Springer     }
6490f241638SMatthias Springer 
6500f241638SMatthias Springer     // Loop bounds and step.
6516825bfe2SNicolas Vasilache     auto lb = locB.create<ConstantIndexOp>(0);
6526825bfe2SNicolas Vasilache     auto ub = locB.create<ConstantIndexOp>(
6536825bfe2SNicolas Vasilache         castedDataType.getDimSize(castedDataType.getRank() - 1));
6546825bfe2SNicolas Vasilache     auto step = locB.create<ConstantIndexOp>(1);
6550f241638SMatthias Springer 
6560f241638SMatthias Springer     // Generate for loop.
6576825bfe2SNicolas Vasilache     locB.create<scf::ForOp>(
6586825bfe2SNicolas Vasilache         lb, ub, step, ValueRange(),
6590f241638SMatthias Springer         [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) {
6600f241638SMatthias Springer           generateInBoundsCheck(
6616825bfe2SNicolas Vasilache               b, xferOp, iv, unpackedDim(xferOp),
6620f241638SMatthias Springer               /*inBoundsCase=*/
6636825bfe2SNicolas Vasilache               [&](OpBuilder &b, Location loc) {
6640f241638SMatthias Springer                 // Create new transfer op.
6652ca887deSMatthias Springer                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
6662ca887deSMatthias Springer                     b, this->options, xferOp, castedDataBuffer, iv);
6670f241638SMatthias Springer 
6680f241638SMatthias Springer                 // If old transfer op has a mask: Set mask on new transfer op.
6690f241638SMatthias Springer                 // Special case: If the mask of the old transfer op is 1D and
6700f241638SMatthias Springer                 // the
6710f241638SMatthias Springer                 //               unpacked dim is not a broadcast, no mask is
6720f241638SMatthias Springer                 //               needed on the new transfer op.
6730f241638SMatthias Springer                 if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
6740f241638SMatthias Springer                                       xferOp.getMaskType().getRank() > 1)) {
6750f241638SMatthias Springer                   OpBuilder::InsertionGuard guard(b);
6760f241638SMatthias Springer                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
6770f241638SMatthias Springer 
6780f241638SMatthias Springer                   SmallVector<Value, 8> loadIndices;
6790f241638SMatthias Springer                   Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
6800f241638SMatthias Springer                   // In case of broadcast: Use same indices to load from memref
6810f241638SMatthias Springer                   // as before.
6820f241638SMatthias Springer                   if (!xferOp.isBroadcastDim(0))
6830f241638SMatthias Springer                     loadIndices.push_back(iv);
6840f241638SMatthias Springer 
6856825bfe2SNicolas Vasilache                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
6866825bfe2SNicolas Vasilache                                                        loadIndices);
6870f241638SMatthias Springer                   rewriter.updateRootInPlace(
6880f241638SMatthias Springer                       newXfer, [&]() { newXfer.maskMutable().assign(mask); });
6890f241638SMatthias Springer                 }
6900f241638SMatthias Springer               },
6910f241638SMatthias Springer               /*outOfBoundsCase=*/
6920f241638SMatthias Springer               [&](OpBuilder &b, Location /*loc*/) {
6930f241638SMatthias Springer                 Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp,
6940f241638SMatthias Springer                                                      castedDataBuffer, iv);
6950f241638SMatthias Springer               });
6960f241638SMatthias Springer           b.create<scf::YieldOp>(loc);
6970f241638SMatthias Springer         });
6980f241638SMatthias Springer 
6990f241638SMatthias Springer     Strategy<OpTy>::cleanup(rewriter, xferOp);
7000f241638SMatthias Springer     return success();
7010f241638SMatthias Springer   }
7020f241638SMatthias Springer };
7030f241638SMatthias Springer 
704a088bed4SMatthias Springer } // namespace lowering_n_d
705a088bed4SMatthias Springer 
706a088bed4SMatthias Springer namespace lowering_n_d_unrolled {
707a088bed4SMatthias Springer 
7080f241638SMatthias Springer /// If the original transfer op has a mask, compute the mask of the new transfer
7090f241638SMatthias Springer /// op (for the current iteration `i`) and assign it.
7100f241638SMatthias Springer template <typename OpTy>
7116825bfe2SNicolas Vasilache static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
7120f241638SMatthias Springer                             int64_t i) {
7130f241638SMatthias Springer   if (!xferOp.mask())
7140f241638SMatthias Springer     return;
7150f241638SMatthias Springer 
7160f241638SMatthias Springer   if (xferOp.isBroadcastDim(0)) {
7170f241638SMatthias Springer     // To-be-unpacked dimension is a broadcast, which does not have a
7180f241638SMatthias Springer     // corresponding mask dimension. Mask attribute remains unchanged.
7190f241638SMatthias Springer     newXferOp.maskMutable().assign(xferOp.mask());
7200f241638SMatthias Springer     return;
7210f241638SMatthias Springer   }
7220f241638SMatthias Springer 
7230f241638SMatthias Springer   if (xferOp.getMaskType().getRank() > 1) {
7240f241638SMatthias Springer     // Unpack one dimension of the mask.
7256825bfe2SNicolas Vasilache     OpBuilder::InsertionGuard guard(b);
7266825bfe2SNicolas Vasilache     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
7270f241638SMatthias Springer 
7280f241638SMatthias Springer     llvm::SmallVector<int64_t, 1> indices({i});
7296825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
7306825bfe2SNicolas Vasilache     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
7310f241638SMatthias Springer     newXferOp.maskMutable().assign(newMask);
7320f241638SMatthias Springer   }
7330f241638SMatthias Springer 
7340f241638SMatthias Springer   // If we end up here: The mask of the old transfer op is 1D and the unpacked
7350f241638SMatthias Springer   // dim is not a broadcast, so no mask is needed on the new transfer op.
7360f241638SMatthias Springer   // `generateInBoundsCheck` will have evaluated the mask already.
7370f241638SMatthias Springer }
7380f241638SMatthias Springer 
7390f241638SMatthias Springer /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
7400f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
7410f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
7420f241638SMatthias Springer ///
7430f241638SMatthias Springer /// ```
7440f241638SMatthias Springer /// E.g.:
7450f241638SMatthias Springer /// ```
7460f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
7470f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<5x4xf32>
7480f241638SMatthias Springer /// ```
7490f241638SMatthias Springer /// is rewritten to IR such as (simplified):
7500f241638SMatthias Springer /// ```
7510f241638SMatthias Springer /// %v_init = splat %padding : vector<5x4xf32>
7520f241638SMatthias Springer /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
7530f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
7540f241638SMatthias Springer /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
7550f241638SMatthias Springer /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
7560f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
7570f241638SMatthias Springer /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
7580f241638SMatthias Springer /// ...
7590f241638SMatthias Springer /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
7600f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
7610f241638SMatthias Springer /// %vec = vector.insert %tmp1, %v3[4] : vector<4xf32> into vector<5x4xf32>
7620f241638SMatthias Springer /// ```
7630f241638SMatthias Springer ///
7640f241638SMatthias Springer /// Note: As an optimization, if the result of the original TransferReadOp
7650f241638SMatthias Springer /// was directly inserted into another vector, no new %v_init vector is created.
7660f241638SMatthias Springer /// Instead, the new TransferReadOp results are inserted into that vector.
7672ca887deSMatthias Springer struct UnrollTransferReadConversion
7682ca887deSMatthias Springer     : public VectorToSCFPattern<TransferReadOp> {
7692ca887deSMatthias Springer   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
7700f241638SMatthias Springer 
7710f241638SMatthias Springer   /// Return the vector into which the newly created TransferReadOp results
7720f241638SMatthias Springer   /// are inserted.
7730f241638SMatthias Springer   Value getResultVector(TransferReadOp xferOp,
7740f241638SMatthias Springer                         PatternRewriter &rewriter) const {
7750f241638SMatthias Springer     if (auto insertOp = getInsertOp(xferOp))
7760f241638SMatthias Springer       return insertOp.dest();
7776825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
7786825bfe2SNicolas Vasilache     return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
7796825bfe2SNicolas Vasilache                                     xferOp.padding());
7800f241638SMatthias Springer   }
7810f241638SMatthias Springer 
7820f241638SMatthias Springer   /// If the result of the TransferReadOp has exactly one user, which is a
7830f241638SMatthias Springer   /// vector::InsertOp, return that operation.
7840f241638SMatthias Springer   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
7850f241638SMatthias Springer     if (xferOp->hasOneUse()) {
7860f241638SMatthias Springer       Operation *xferOpUser = *xferOp->getUsers().begin();
7870f241638SMatthias Springer       if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
7880f241638SMatthias Springer         return insertOp;
7890f241638SMatthias Springer     }
7900f241638SMatthias Springer 
7910f241638SMatthias Springer     return vector::InsertOp();
7920f241638SMatthias Springer   }
7930f241638SMatthias Springer 
7940f241638SMatthias Springer   /// If the result of the TransferReadOp has exactly one user, which is a
7950f241638SMatthias Springer   /// vector::InsertOp, return that operation's indices.
7960f241638SMatthias Springer   void getInsertionIndices(TransferReadOp xferOp,
7970f241638SMatthias Springer                            SmallVector<int64_t, 8> &indices) const {
7980f241638SMatthias Springer     if (auto insertOp = getInsertOp(xferOp)) {
7990f241638SMatthias Springer       llvm::for_each(insertOp.position(), [&](Attribute attr) {
8000f241638SMatthias Springer         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
8010f241638SMatthias Springer       });
8020f241638SMatthias Springer     }
8030f241638SMatthias Springer   }
8040f241638SMatthias Springer 
8050f241638SMatthias Springer   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
8060f241638SMatthias Springer   /// accesses, and broadcasts and transposes in permutation maps.
8070f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferReadOp xferOp,
8080f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
8092ca887deSMatthias Springer     if (xferOp.getVectorType().getRank() <= options.targetRank)
8100f241638SMatthias Springer       return failure();
8118fb48979SMatthias Springer     if (xferOp.getShapedType().template isa<RankedTensorType>())
8128fb48979SMatthias Springer       return failure();
813*f718a53dSMatthias Springer     // Transfer ops that modify the element type are not supported atm.
814*f718a53dSMatthias Springer     if (xferOp.getVectorType().getElementType() !=
815*f718a53dSMatthias Springer         xferOp.getShapedType().getElementType())
816*f718a53dSMatthias Springer       return failure();
8170f241638SMatthias Springer 
8180f241638SMatthias Springer     auto insertOp = getInsertOp(xferOp);
8190f241638SMatthias Springer     auto vec = getResultVector(xferOp, rewriter);
8200f241638SMatthias Springer     auto vecType = vec.getType().dyn_cast<VectorType>();
8210f241638SMatthias Springer     auto xferVecType = xferOp.getVectorType();
8220f241638SMatthias Springer     auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
8230f241638SMatthias Springer                                           xferVecType.getElementType());
8240f241638SMatthias Springer     int64_t dimSize = xferVecType.getShape()[0];
8250f241638SMatthias Springer 
8260f241638SMatthias Springer     // Generate fully unrolled loop of transfer ops.
8276825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
8280f241638SMatthias Springer     for (int64_t i = 0; i < dimSize; ++i) {
8296825bfe2SNicolas Vasilache       Value iv = rewriter.create<ConstantIndexOp>(loc, i);
8300f241638SMatthias Springer 
8310f241638SMatthias Springer       vec = generateInBoundsCheck(
8326825bfe2SNicolas Vasilache           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
8330f241638SMatthias Springer           /*inBoundsCase=*/
8340f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
8350f241638SMatthias Springer             // Indices for the new transfer op.
8360f241638SMatthias Springer             SmallVector<Value, 8> xferIndices;
8376825bfe2SNicolas Vasilache             getXferIndices(b, xferOp, iv, xferIndices);
8380f241638SMatthias Springer 
8390f241638SMatthias Springer             // Indices for the new vector.insert op.
8400f241638SMatthias Springer             SmallVector<int64_t, 8> insertionIndices;
8410f241638SMatthias Springer             getInsertionIndices(xferOp, insertionIndices);
8420f241638SMatthias Springer             insertionIndices.push_back(i);
8430f241638SMatthias Springer 
8440f241638SMatthias Springer             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
8456825bfe2SNicolas Vasilache             auto newXferOp = b.create<vector::TransferReadOp>(
8466825bfe2SNicolas Vasilache                 loc, newXferVecType, xferOp.source(), xferIndices,
8476825bfe2SNicolas Vasilache                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
8486825bfe2SNicolas Vasilache                 xferOp.padding(), Value(), inBoundsAttr);
8490f241638SMatthias Springer             maybeAssignMask(b, xferOp, newXferOp, i);
8506825bfe2SNicolas Vasilache             return b.create<vector::InsertOp>(loc, newXferOp, vec,
8516825bfe2SNicolas Vasilache                                               insertionIndices);
8520f241638SMatthias Springer           },
8530f241638SMatthias Springer           /*outOfBoundsCase=*/
8540f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
8550f241638SMatthias Springer             // Loop through original (unmodified) vector.
8560f241638SMatthias Springer             return vec;
8570f241638SMatthias Springer           });
8580f241638SMatthias Springer     }
8590f241638SMatthias Springer 
8600f241638SMatthias Springer     if (insertOp) {
8610f241638SMatthias Springer       // Rewrite single user of the old TransferReadOp, which was an InsertOp.
8620f241638SMatthias Springer       rewriter.replaceOp(insertOp, vec);
8630f241638SMatthias Springer       rewriter.eraseOp(xferOp);
8640f241638SMatthias Springer     } else {
8650f241638SMatthias Springer       rewriter.replaceOp(xferOp, vec);
8660f241638SMatthias Springer     }
8670f241638SMatthias Springer 
8680f241638SMatthias Springer     return success();
8690f241638SMatthias Springer   }
8700f241638SMatthias Springer };
8710f241638SMatthias Springer 
8720f241638SMatthias Springer /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
8730f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
8740f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
8750f241638SMatthias Springer ///
/// E.g.:
8780f241638SMatthias Springer /// ```
8790f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c]
8800f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
8810f241638SMatthias Springer /// ```
8820f241638SMatthias Springer /// is rewritten to IR such as (simplified):
8830f241638SMatthias Springer /// ```
8840f241638SMatthias Springer /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
8850f241638SMatthias Springer /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
8860f241638SMatthias Springer /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
8870f241638SMatthias Springer /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
8880f241638SMatthias Springer /// ...
8890f241638SMatthias Springer /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
8900f241638SMatthias Springer /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
8910f241638SMatthias Springer /// ```
8920f241638SMatthias Springer ///
8930f241638SMatthias Springer /// Note: As an optimization, if the vector of the original TransferWriteOp
8940f241638SMatthias Springer /// was directly extracted from another vector via an ExtractOp `a`, extract
8950f241638SMatthias Springer /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
8960f241638SMatthias Springer /// doing so, `a` may become dead, and the number of ExtractOps generated during
8970f241638SMatthias Springer /// recursive application of this pattern will be minimal.
8980f241638SMatthias Springer struct UnrollTransferWriteConversion
8992ca887deSMatthias Springer     : public VectorToSCFPattern<TransferWriteOp> {
9002ca887deSMatthias Springer   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
9010f241638SMatthias Springer 
9020f241638SMatthias Springer   /// Return the vector from which newly generated ExtracOps will extract.
9030f241638SMatthias Springer   Value getDataVector(TransferWriteOp xferOp) const {
9040f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp))
9050f241638SMatthias Springer       return extractOp.vector();
9060f241638SMatthias Springer     return xferOp.vector();
9070f241638SMatthias Springer   }
9080f241638SMatthias Springer 
9090f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
9100f241638SMatthias Springer   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
9110f241638SMatthias Springer     if (auto *op = xferOp.vector().getDefiningOp())
9120f241638SMatthias Springer       return dyn_cast<vector::ExtractOp>(op);
9130f241638SMatthias Springer     return vector::ExtractOp();
9140f241638SMatthias Springer   }
9150f241638SMatthias Springer 
9160f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return its
9170f241638SMatthias Springer   /// indices.
9180f241638SMatthias Springer   void getExtractionIndices(TransferWriteOp xferOp,
9190f241638SMatthias Springer                             SmallVector<int64_t, 8> &indices) const {
9200f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp)) {
9210f241638SMatthias Springer       llvm::for_each(extractOp.position(), [&](Attribute attr) {
9220f241638SMatthias Springer         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
9230f241638SMatthias Springer       });
9240f241638SMatthias Springer     }
9250f241638SMatthias Springer   }
9260f241638SMatthias Springer 
  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  ///
  /// The leading vector dimension is fully unrolled. For every index i in
  /// [0, dimSize), element i of the written vector is taken with a
  /// vector.extract and written by a new TransferWriteOp. That TransferWriteOp
  /// is emitted inside generateInBoundsCheck. The original op is erased
  /// afterwards.
  ///
  /// Does not match if the vector rank is already <= options.targetRank, if
  /// the shaped operand is a ranked tensor, or if the vector element type
  /// differs from the shaped operand's element type.
  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (xferOp.getShapedType().template isa<RankedTensorType>())
      return failure();
    // Transfer ops that modify the element type are not supported atm.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    // If the written vector itself comes from a vector.extract, extract
    // directly from that op's source. The new positions are then the
    // existing extraction position followed by i (see below).
    auto vec = getDataVector(xferOp);
    auto xferVecType = xferOp.getVectorType();
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<ConstantIndexOp>(loc, i);

      // Only an in-bounds case is supplied. Presumably nothing is written
      // when index i is out of bounds along unpackedDim(xferOp); confirm
      // against generateInBoundsCheck.
      generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp),
          /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
            SmallVector<Value, 8> xferIndices;
            getXferIndices(b, xferOp, iv, xferIndices);

            // Indices for the new vector.extract op.
            SmallVector<int64_t, 8> extractionIndices;
            getExtractionIndices(xferOp, extractionIndices);
            extractionIndices.push_back(i);

            auto extracted =
                b.create<vector::ExtractOp>(loc, vec, extractionIndices);
            // The new op keeps the in_bounds flags of the remaining
            // (non-unrolled) dimensions.
            auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());

            auto newXferOp = b.create<vector::TransferWriteOp>(
                loc, Type(), extracted, xferOp.source(), xferIndices,
                AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
                inBoundsAttr);

            // NOTE(review): presumably attaches the i-th slice of xferOp's
            // mask (if any) to newXferOp; defined outside this view.
            maybeAssignMask(b, xferOp, newXferOp, i);
          });
    }

    rewriter.eraseOp(xferOp);
    return success();
  }
9770f241638SMatthias Springer };
9780f241638SMatthias Springer 
979a088bed4SMatthias Springer } // namespace lowering_n_d_unrolled
980a088bed4SMatthias Springer 
981a088bed4SMatthias Springer namespace lowering_1_d {
982a088bed4SMatthias Springer 
9830f241638SMatthias Springer /// Compute the indices into the memref for the LoadOp/StoreOp generated as
9840f241638SMatthias Springer /// part of TransferOp1dConversion. Return the memref dimension on which
9850f241638SMatthias Springer /// the transfer is operating. A return value of None indicates a broadcast.
9860f241638SMatthias Springer template <typename OpTy>
9870f241638SMatthias Springer static Optional<int64_t>
9886825bfe2SNicolas Vasilache get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
9890f241638SMatthias Springer                    SmallVector<Value, 8> &memrefIndices) {
9900f241638SMatthias Springer   auto indices = xferOp.indices();
9910f241638SMatthias Springer   auto map = xferOp.permutation_map();
9920f241638SMatthias Springer 
9930f241638SMatthias Springer   memrefIndices.append(indices.begin(), indices.end());
9940f241638SMatthias Springer   assert(map.getNumResults() == 1 &&
9950f241638SMatthias Springer          "Expected 1 permutation map result for 1D transfer");
9960f241638SMatthias Springer   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
9976825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
9980f241638SMatthias Springer     auto dim = expr.getPosition();
9996825bfe2SNicolas Vasilache     AffineExpr d0, d1;
10006825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
10016825bfe2SNicolas Vasilache     Value offset = memrefIndices[dim];
10026825bfe2SNicolas Vasilache     memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
10030f241638SMatthias Springer     return dim;
10040f241638SMatthias Springer   }
10050f241638SMatthias Springer 
10060f241638SMatthias Springer   assert(xferOp.isBroadcastDim(0) &&
10070f241638SMatthias Springer          "Expected AffineDimExpr or AffineConstantExpr");
10080f241638SMatthias Springer   return None;
10090f241638SMatthias Springer }
10100f241638SMatthias Springer 
/// Codegen strategy for TransferOp1dConversion, depending on the
/// operation. Each specialization provides:
///  * generateForLoopBody(b, loc, xferOp, iv, loopState): emit the body of
///    one scf.for iteration, including its scf.yield terminator.
///  * initialLoopState(b, xferOp): the initial loop-carried value, or a null
///    Value if the loop carries no state.
template <typename OpTy>
struct Strategy1d;
10150f241638SMatthias Springer 
10160f241638SMatthias Springer /// Codegen strategy for TransferReadOp.
10170f241638SMatthias Springer template <>
10180f241638SMatthias Springer struct Strategy1d<TransferReadOp> {
10196825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
10200f241638SMatthias Springer                                   TransferReadOp xferOp, Value iv,
10210f241638SMatthias Springer                                   ValueRange loopState) {
10220f241638SMatthias Springer     SmallVector<Value, 8> indices;
10236825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
10246825bfe2SNicolas Vasilache     Value ivI32 =
10256825bfe2SNicolas Vasilache         b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
10260f241638SMatthias Springer     auto vec = loopState[0];
10270f241638SMatthias Springer 
10280f241638SMatthias Springer     // In case of out-of-bounds access, leave `vec` as is (was initialized with
10290f241638SMatthias Springer     // padding value).
10300f241638SMatthias Springer     auto nextVec = generateInBoundsCheck(
10316825bfe2SNicolas Vasilache         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
10320f241638SMatthias Springer         /*inBoundsCase=*/
10336825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
10346825bfe2SNicolas Vasilache           Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
10356825bfe2SNicolas Vasilache           return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
10360f241638SMatthias Springer         },
10370f241638SMatthias Springer         /*outOfBoundsCase=*/
10380f241638SMatthias Springer         [&](OpBuilder & /*b*/, Location loc) { return vec; });
10396825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc, nextVec);
10400f241638SMatthias Springer   }
10410f241638SMatthias Springer 
10426825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
10430f241638SMatthias Springer     // Inititalize vector with padding value.
10446825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
10456825bfe2SNicolas Vasilache     return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
10460f241638SMatthias Springer   }
10470f241638SMatthias Springer };
10480f241638SMatthias Springer 
10490f241638SMatthias Springer /// Codegen strategy for TransferWriteOp.
10500f241638SMatthias Springer template <>
10510f241638SMatthias Springer struct Strategy1d<TransferWriteOp> {
10526825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
10530f241638SMatthias Springer                                   TransferWriteOp xferOp, Value iv,
10540f241638SMatthias Springer                                   ValueRange /*loopState*/) {
10550f241638SMatthias Springer     SmallVector<Value, 8> indices;
10566825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
10576825bfe2SNicolas Vasilache     Value ivI32 =
10586825bfe2SNicolas Vasilache         b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
10590f241638SMatthias Springer 
10600f241638SMatthias Springer     // Nothing to do in case of out-of-bounds access.
10610f241638SMatthias Springer     generateInBoundsCheck(
10626825bfe2SNicolas Vasilache         b, xferOp, iv, dim,
10636825bfe2SNicolas Vasilache         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
10646825bfe2SNicolas Vasilache           auto val =
10656825bfe2SNicolas Vasilache               b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
10666825bfe2SNicolas Vasilache           b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
10670f241638SMatthias Springer         });
10686825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc);
10690f241638SMatthias Springer   }
10700f241638SMatthias Springer 
10716825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
10726825bfe2SNicolas Vasilache     return Value();
10736825bfe2SNicolas Vasilache   }
10740f241638SMatthias Springer };
10750f241638SMatthias Springer 
10760f241638SMatthias Springer /// Return true if the last dimension of the MemRefType has unit stride.
10770f241638SMatthias Springer static bool isLastMemrefDimUnitStride(MemRefType type) {
10780f241638SMatthias Springer   int64_t offset;
10790f241638SMatthias Springer   SmallVector<int64_t, 4> strides;
10800f241638SMatthias Springer   auto successStrides = getStridesAndOffset(type, strides, offset);
10815017b0f8SMatthias Springer   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
10820f241638SMatthias Springer }
10830f241638SMatthias Springer 
10840f241638SMatthias Springer /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
10850f241638SMatthias Springer /// necessary in cases where a 1D vector transfer op cannot be lowered into
10860f241638SMatthias Springer /// vector load/stores due to non-unit strides or broadcasts:
10870f241638SMatthias Springer ///
10880f241638SMatthias Springer /// * Transfer dimension is not the last memref dimension
10890f241638SMatthias Springer /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
10900f241638SMatthias Springer /// * Memref has a layout map with non-unit stride on the last dimension
10910f241638SMatthias Springer ///
10920f241638SMatthias Springer /// This pattern generates IR as follows:
10930f241638SMatthias Springer ///
10940f241638SMatthias Springer /// 1. Generate a for loop iterating over each vector element.
10950f241638SMatthias Springer /// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
10960f241638SMatthias Springer ///    depending on OpTy.
10970f241638SMatthias Springer ///
10980f241638SMatthias Springer /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
10990f241638SMatthias Springer ///       can be generated instead of TransferOp1dConversion. Add such a pattern
11000f241638SMatthias Springer ///       to ConvertVectorToLLVM.
11010f241638SMatthias Springer ///
11020f241638SMatthias Springer /// E.g.:
11030f241638SMatthias Springer /// ```
11040f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b]
11050f241638SMatthias Springer ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
11060f241638SMatthias Springer ///    : vector<9xf32>, memref<?x?xf32>
11070f241638SMatthias Springer /// ```
11080f241638SMatthias Springer /// Is rewritten to approximately the following pseudo-IR:
11090f241638SMatthias Springer /// ```
11100f241638SMatthias Springer /// for i = 0 to 9 {
11110f241638SMatthias Springer ///   %t = vector.extractelement %vec[i] : vector<9xf32>
11120f241638SMatthias Springer ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
11130f241638SMatthias Springer /// }
11140f241638SMatthias Springer /// ```
11150f241638SMatthias Springer template <typename OpTy>
11162ca887deSMatthias Springer struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
11172ca887deSMatthias Springer   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
11180f241638SMatthias Springer 
11190f241638SMatthias Springer   LogicalResult matchAndRewrite(OpTy xferOp,
11200f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
11210f241638SMatthias Springer     auto map = xferOp.permutation_map();
11220f241638SMatthias Springer     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
11230f241638SMatthias Springer 
11240f241638SMatthias Springer     if (!memRefType)
11250f241638SMatthias Springer       return failure();
11260f241638SMatthias Springer     if (xferOp.getVectorType().getRank() != 1)
11270f241638SMatthias Springer       return failure();
11280f241638SMatthias Springer     if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
11290f241638SMatthias Springer       return failure(); // Handled by ConvertVectorToLLVM
11300f241638SMatthias Springer 
11310f241638SMatthias Springer     // Loop bounds, step, state...
11326825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
11330f241638SMatthias Springer     auto vecType = xferOp.getVectorType();
11346825bfe2SNicolas Vasilache     auto lb = rewriter.create<ConstantIndexOp>(loc, 0);
11356825bfe2SNicolas Vasilache     auto ub = rewriter.create<ConstantIndexOp>(loc, vecType.getDimSize(0));
11366825bfe2SNicolas Vasilache     auto step = rewriter.create<ConstantIndexOp>(loc, 1);
11376825bfe2SNicolas Vasilache     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
11380f241638SMatthias Springer 
11390f241638SMatthias Springer     // Generate for loop.
11400f241638SMatthias Springer     rewriter.replaceOpWithNewOp<scf::ForOp>(
11410f241638SMatthias Springer         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
11426825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
11436825bfe2SNicolas Vasilache           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
11440f241638SMatthias Springer         });
11450f241638SMatthias Springer 
11460f241638SMatthias Springer     return success();
11470f241638SMatthias Springer   }
11480f241638SMatthias Springer };
11494ead2cf7SAlex Zinenko 
1150a088bed4SMatthias Springer } // namespace lowering_1_d
1151df63eedeSBenjamin Kramer } // namespace
1152df63eedeSBenjamin Kramer 
115351d30c34SBenjamin Kramer namespace mlir {
115451d30c34SBenjamin Kramer 
11553393cc4cSNicolas Vasilache void populateVectorToSCFConversionPatterns(
1156dc4e913bSChris Lattner     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
11570f241638SMatthias Springer   if (options.unroll) {
1158a088bed4SMatthias Springer     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1159a088bed4SMatthias Springer                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
11602ca887deSMatthias Springer         patterns.getContext(), options);
11610f241638SMatthias Springer   } else {
1162a088bed4SMatthias Springer     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1163a088bed4SMatthias Springer                  lowering_n_d::PrepareTransferWriteConversion,
1164a088bed4SMatthias Springer                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1165a088bed4SMatthias Springer                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1166a088bed4SMatthias Springer         patterns.getContext(), options);
11670f241638SMatthias Springer   }
11680f241638SMatthias Springer 
11692ca887deSMatthias Springer   if (options.targetRank == 1) {
1170a088bed4SMatthias Springer     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1171a088bed4SMatthias Springer                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1172a088bed4SMatthias Springer         patterns.getContext(), options);
11730f241638SMatthias Springer   }
11744ead2cf7SAlex Zinenko }
11753393cc4cSNicolas Vasilache 
11763393cc4cSNicolas Vasilache } // namespace mlir
11773393cc4cSNicolas Vasilache 
11785f9e0466SNicolas Vasilache namespace {
11795f9e0466SNicolas Vasilache 
11805f9e0466SNicolas Vasilache struct ConvertVectorToSCFPass
11815f9e0466SNicolas Vasilache     : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
11825f9e0466SNicolas Vasilache   ConvertVectorToSCFPass() = default;
11835f9e0466SNicolas Vasilache   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
11845f9e0466SNicolas Vasilache     this->fullUnroll = options.unroll;
11852ca887deSMatthias Springer     this->targetRank = options.targetRank;
1186fb7ec1f1SMatthias Springer     this->lowerPermutationMaps = options.lowerPermutationMaps;
11875f9e0466SNicolas Vasilache   }
11885f9e0466SNicolas Vasilache 
11895f9e0466SNicolas Vasilache   void runOnFunction() override {
11902ca887deSMatthias Springer     VectorTransferToSCFOptions options;
1191fb7ec1f1SMatthias Springer     options.unroll = fullUnroll;
1192fb7ec1f1SMatthias Springer     options.targetRank = targetRank;
1193fb7ec1f1SMatthias Springer     options.lowerPermutationMaps = lowerPermutationMaps;
1194fb7ec1f1SMatthias Springer 
1195fb7ec1f1SMatthias Springer     // Lower permutation maps first.
1196fb7ec1f1SMatthias Springer     if (lowerPermutationMaps) {
1197fb7ec1f1SMatthias Springer       RewritePatternSet lowerTransferPatterns(getFunction().getContext());
1198fb7ec1f1SMatthias Springer       mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1199fb7ec1f1SMatthias Springer           lowerTransferPatterns);
1200fb7ec1f1SMatthias Springer       (void)applyPatternsAndFoldGreedily(getFunction(),
1201fb7ec1f1SMatthias Springer                                          std::move(lowerTransferPatterns));
1202fb7ec1f1SMatthias Springer     }
12032ca887deSMatthias Springer 
1204dc4e913bSChris Lattner     RewritePatternSet patterns(getFunction().getContext());
12052ca887deSMatthias Springer     populateVectorToSCFConversionPatterns(patterns, options);
1206e21adfa3SRiver Riddle     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
12075f9e0466SNicolas Vasilache   }
12085f9e0466SNicolas Vasilache };
12095f9e0466SNicolas Vasilache 
12105f9e0466SNicolas Vasilache } // namespace
12115f9e0466SNicolas Vasilache 
/// Create a ConvertVectorToSCFPass whose pass options (fullUnroll,
/// targetRank, lowerPermutationMaps) are initialized from `options`.
std::unique_ptr<Pass>
mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
  return std::make_unique<ConvertVectorToSCFPass>(options);
}
1216