10f241638SMatthias Springer //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
24ead2cf7SAlex Zinenko //
34ead2cf7SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44ead2cf7SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
54ead2cf7SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64ead2cf7SAlex Zinenko //
74ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===//
84ead2cf7SAlex Zinenko //
90f241638SMatthias Springer // This file implements lowering of vector transfer operations to SCF.
104ead2cf7SAlex Zinenko //
114ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===//
124ead2cf7SAlex Zinenko 
134ead2cf7SAlex Zinenko #include <type_traits>
144ead2cf7SAlex Zinenko 
154ead2cf7SAlex Zinenko #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
165f9e0466SNicolas Vasilache 
175f9e0466SNicolas Vasilache #include "../PassDetail.h"
186825bfe2SNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h"
19a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
2066f878ceSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
218b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
2299ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
234ead2cf7SAlex Zinenko #include "mlir/IR/Builders.h"
246825bfe2SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
255f9e0466SNicolas Vasilache #include "mlir/Pass/Pass.h"
26b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
275f9e0466SNicolas Vasilache #include "mlir/Transforms/Passes.h"
284ead2cf7SAlex Zinenko 
294ead2cf7SAlex Zinenko using namespace mlir;
304ead2cf7SAlex Zinenko using vector::TransferReadOp;
314ead2cf7SAlex Zinenko using vector::TransferWriteOp;
324ead2cf7SAlex Zinenko 
33350dadaaSBenjamin Kramer namespace {
340f241638SMatthias Springer 
/// Name of the attribute that marks transfer ops which this pass still has to
/// lower further (see `maybeApplyPassLabel`).
static constexpr char kPassLabel[] = "__vector_to_scf_lowering__";
370f241638SMatthias Springer 
382ca887deSMatthias Springer /// Patterns that inherit from this struct have access to
392ca887deSMatthias Springer /// VectorTransferToSCFOptions.
402ca887deSMatthias Springer template <typename OpTy>
412ca887deSMatthias Springer struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
VectorToSCFPattern__anon4d9edda10111::VectorToSCFPattern422ca887deSMatthias Springer   explicit VectorToSCFPattern(MLIRContext *context,
432ca887deSMatthias Springer                               VectorTransferToSCFOptions opt)
442ca887deSMatthias Springer       : OpRewritePattern<OpTy>(context), options(opt) {}
452ca887deSMatthias Springer 
462ca887deSMatthias Springer   VectorTransferToSCFOptions options;
472ca887deSMatthias Springer };
480f241638SMatthias Springer 
490f241638SMatthias Springer /// Given a vector transfer op, calculate which dimension of the `source`
500f241638SMatthias Springer /// memref should be unpacked in the next application of TransferOpConversion.
510f241638SMatthias Springer /// A return value of None indicates a broadcast.
520f241638SMatthias Springer template <typename OpTy>
unpackedDim(OpTy xferOp)530f241638SMatthias Springer static Optional<int64_t> unpackedDim(OpTy xferOp) {
54c537a943SNicolas Vasilache   // TODO: support 0-d corner case.
55c537a943SNicolas Vasilache   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
567c38fd60SJacques Pienaar   auto map = xferOp.getPermutationMap();
570f241638SMatthias Springer   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
580f241638SMatthias Springer     return expr.getPosition();
597c3c5b11SNicolas Vasilache   }
600f241638SMatthias Springer   assert(xferOp.isBroadcastDim(0) &&
610f241638SMatthias Springer          "Expected AffineDimExpr or AffineConstantExpr");
620f241638SMatthias Springer   return None;
630f241638SMatthias Springer }
640f241638SMatthias Springer 
650f241638SMatthias Springer /// Compute the permutation map for the new (N-1)-D vector transfer op. This
660f241638SMatthias Springer /// map is identical to the current permutation map, but the first result is
670f241638SMatthias Springer /// omitted.
680f241638SMatthias Springer template <typename OpTy>
unpackedPermutationMap(OpBuilder & b,OpTy xferOp)696825bfe2SNicolas Vasilache static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
70c537a943SNicolas Vasilache   // TODO: support 0-d corner case.
71c537a943SNicolas Vasilache   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
727c38fd60SJacques Pienaar   auto map = xferOp.getPermutationMap();
730f241638SMatthias Springer   return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
746825bfe2SNicolas Vasilache                         b.getContext());
750f241638SMatthias Springer }
760f241638SMatthias Springer 
770f241638SMatthias Springer /// Calculate the indices for the new vector transfer op.
780f241638SMatthias Springer ///
790f241638SMatthias Springer /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
800f241638SMatthias Springer ///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32>
810f241638SMatthias Springer ///                                 ^^^^^^
820f241638SMatthias Springer ///              `iv` is the iteration variable of the (new) surrounding loop.
830f241638SMatthias Springer template <typename OpTy>
getXferIndices(OpBuilder & b,OpTy xferOp,Value iv,SmallVector<Value,8> & indices)846825bfe2SNicolas Vasilache static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
850f241638SMatthias Springer                            SmallVector<Value, 8> &indices) {
860f241638SMatthias Springer   typename OpTy::Adaptor adaptor(xferOp);
870f241638SMatthias Springer   // Corresponding memref dim of the vector dim that is unpacked.
880f241638SMatthias Springer   auto dim = unpackedDim(xferOp);
897c38fd60SJacques Pienaar   auto prevIndices = adaptor.getIndices();
900f241638SMatthias Springer   indices.append(prevIndices.begin(), prevIndices.end());
910f241638SMatthias Springer 
926825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
93491d2701SKazu Hirata   bool isBroadcast = !dim.has_value();
940f241638SMatthias Springer   if (!isBroadcast) {
956825bfe2SNicolas Vasilache     AffineExpr d0, d1;
966825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
97c27d8152SKazu Hirata     Value offset = adaptor.getIndices()[dim.value()];
98c27d8152SKazu Hirata     indices[dim.value()] =
996825bfe2SNicolas Vasilache         makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1000f241638SMatthias Springer   }
1010f241638SMatthias Springer }
1020f241638SMatthias Springer 
maybeYieldValue(OpBuilder & b,Location loc,bool hasRetVal,Value value)1036825bfe2SNicolas Vasilache static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
1040f241638SMatthias Springer                             Value value) {
1050f241638SMatthias Springer   if (hasRetVal) {
106558e7401SMatthias Springer     assert(value && "Expected non-empty value");
1076825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc, value);
1080f241638SMatthias Springer   } else {
1096825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc);
1100f241638SMatthias Springer   }
1110f241638SMatthias Springer }
1120f241638SMatthias Springer 
1130f241638SMatthias Springer /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
1140f241638SMatthias Springer /// is set to true. No such check is generated under following circumstances:
1150f241638SMatthias Springer /// * xferOp does not have a mask.
1160f241638SMatthias Springer /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
1170f241638SMatthias Springer ///   computed and attached to the new transfer op in the pattern.)
1180f241638SMatthias Springer /// * The to-be-unpacked dim of xferOp is a broadcast.
1190f241638SMatthias Springer template <typename OpTy>
generateMaskCheck(OpBuilder & b,OpTy xferOp,Value iv)1206825bfe2SNicolas Vasilache static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
1217c38fd60SJacques Pienaar   if (!xferOp.getMask())
1220f241638SMatthias Springer     return Value();
1230f241638SMatthias Springer   if (xferOp.getMaskType().getRank() != 1)
1240f241638SMatthias Springer     return Value();
1250f241638SMatthias Springer   if (xferOp.isBroadcastDim(0))
1260f241638SMatthias Springer     return Value();
1270f241638SMatthias Springer 
1286825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
1297c38fd60SJacques Pienaar   return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
1300f241638SMatthias Springer }
1310f241638SMatthias Springer 
1320f241638SMatthias Springer /// Helper function TransferOpConversion and TransferOp1dConversion.
1330f241638SMatthias Springer /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
1340f241638SMatthias Springer /// specified dimension `dim` with the loop iteration variable `iv`.
1350f241638SMatthias Springer /// E.g., when unpacking dimension 0 from:
1360f241638SMatthias Springer /// ```
1370f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b] %cst
1380f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?xf32>
1390f241638SMatthias Springer /// ```
1400f241638SMatthias Springer /// An if check similar to this will be generated inside the loop:
1410f241638SMatthias Springer /// ```
1420f241638SMatthias Springer /// %d = memref.dim %A, %c0 : memref<?x?xf32>
1430f241638SMatthias Springer /// if (%a + iv < %d) {
1440f241638SMatthias Springer ///   (in-bounds case)
1450f241638SMatthias Springer /// } else {
1460f241638SMatthias Springer ///   (out-of-bounds case)
1470f241638SMatthias Springer /// }
1480f241638SMatthias Springer /// ```
1490f241638SMatthias Springer ///
1500f241638SMatthias Springer /// If the transfer is 1D and has a mask, this function generates a more complex
1510f241638SMatthias Springer /// check also accounts for potentially masked out elements.
1520f241638SMatthias Springer ///
1530f241638SMatthias Springer /// This function variant returns the value returned by `inBoundsCase` or
1540f241638SMatthias Springer /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
1550f241638SMatthias Springer /// `resultTypes`.
1560f241638SMatthias Springer template <typename OpTy>
generateInBoundsCheck(OpBuilder & b,OpTy xferOp,Value iv,Optional<int64_t> dim,TypeRange resultTypes,function_ref<Value (OpBuilder &,Location)> inBoundsCase,function_ref<Value (OpBuilder &,Location)> outOfBoundsCase=nullptr)1570f241638SMatthias Springer static Value generateInBoundsCheck(
1586825bfe2SNicolas Vasilache     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
1590f241638SMatthias Springer     TypeRange resultTypes,
1600f241638SMatthias Springer     function_ref<Value(OpBuilder &, Location)> inBoundsCase,
1610f241638SMatthias Springer     function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
1620f241638SMatthias Springer   bool hasRetVal = !resultTypes.empty();
1630f241638SMatthias Springer   Value cond; // Condition to be built...
1640f241638SMatthias Springer 
1650f241638SMatthias Springer   // Condition check 1: Access in-bounds?
1660916d96dSKazu Hirata   bool isBroadcast = !dim; // No in-bounds check for broadcasts.
1676825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
1686825bfe2SNicolas Vasilache   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
1690f241638SMatthias Springer   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
1707c38fd60SJacques Pienaar     Value memrefDim =
1717c38fd60SJacques Pienaar         vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
1726825bfe2SNicolas Vasilache     AffineExpr d0, d1;
1736825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
1746d5fc1e3SKazu Hirata     Value base = xferOp.getIndices()[*dim];
1756825bfe2SNicolas Vasilache     Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
176a54f4eaeSMogball     cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
177a54f4eaeSMogball                                     memrefIdx);
1780f241638SMatthias Springer   }
1790f241638SMatthias Springer 
1800f241638SMatthias Springer   // Condition check 2: Masked in?
1816825bfe2SNicolas Vasilache   if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
1826825bfe2SNicolas Vasilache     if (cond)
183a54f4eaeSMogball       cond = lb.create<arith::AndIOp>(cond, maskCond);
1846825bfe2SNicolas Vasilache     else
1850f241638SMatthias Springer       cond = maskCond;
1860f241638SMatthias Springer   }
1870f241638SMatthias Springer 
1880f241638SMatthias Springer   // If the condition is non-empty, generate an SCF::IfOp.
1890f241638SMatthias Springer   if (cond) {
1906825bfe2SNicolas Vasilache     auto check = lb.create<scf::IfOp>(
1916825bfe2SNicolas Vasilache         resultTypes, cond,
1920f241638SMatthias Springer         /*thenBuilder=*/
1936825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
1946825bfe2SNicolas Vasilache           maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
195cadb7ccfSAlex Zinenko         },
1960f241638SMatthias Springer         /*elseBuilder=*/
1976825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
1980f241638SMatthias Springer           if (outOfBoundsCase) {
1996825bfe2SNicolas Vasilache             maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
2007c3c5b11SNicolas Vasilache           } else {
2016825bfe2SNicolas Vasilache             b.create<scf::YieldOp>(loc);
2027c3c5b11SNicolas Vasilache           }
2037c3c5b11SNicolas Vasilache         });
2047c3c5b11SNicolas Vasilache 
2050f241638SMatthias Springer     return hasRetVal ? check.getResult(0) : Value();
2064ead2cf7SAlex Zinenko   }
2074ead2cf7SAlex Zinenko 
2080f241638SMatthias Springer   // Condition is empty, no need for an SCF::IfOp.
2096825bfe2SNicolas Vasilache   return inBoundsCase(b, loc);
2100f241638SMatthias Springer }
2110f241638SMatthias Springer 
2120f241638SMatthias Springer /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
2130f241638SMatthias Springer /// a return value. Consequently, this function does not have a return value.
2140f241638SMatthias Springer template <typename OpTy>
generateInBoundsCheck(OpBuilder & b,OpTy xferOp,Value iv,Optional<int64_t> dim,function_ref<void (OpBuilder &,Location)> inBoundsCase,function_ref<void (OpBuilder &,Location)> outOfBoundsCase=nullptr)2150f241638SMatthias Springer static void generateInBoundsCheck(
2166825bfe2SNicolas Vasilache     OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
2170f241638SMatthias Springer     function_ref<void(OpBuilder &, Location)> inBoundsCase,
2180f241638SMatthias Springer     function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
2190f241638SMatthias Springer   generateInBoundsCheck(
2206825bfe2SNicolas Vasilache       b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
2210f241638SMatthias Springer       /*inBoundsCase=*/
2226825bfe2SNicolas Vasilache       [&](OpBuilder &b, Location loc) {
2236825bfe2SNicolas Vasilache         inBoundsCase(b, loc);
2240f241638SMatthias Springer         return Value();
2250f241638SMatthias Springer       },
2260f241638SMatthias Springer       /*outOfBoundsCase=*/
2276825bfe2SNicolas Vasilache       [&](OpBuilder &b, Location loc) {
2280f241638SMatthias Springer         if (outOfBoundsCase)
2296825bfe2SNicolas Vasilache           outOfBoundsCase(b, loc);
2300f241638SMatthias Springer         return Value();
2310f241638SMatthias Springer       });
2320f241638SMatthias Springer }
2330f241638SMatthias Springer 
2340f241638SMatthias Springer /// Given an ArrayAttr, return a copy where the first element is dropped.
dropFirstElem(OpBuilder & b,ArrayAttr attr)2356825bfe2SNicolas Vasilache static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
2360f241638SMatthias Springer   if (!attr)
2370f241638SMatthias Springer     return attr;
2386825bfe2SNicolas Vasilache   return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
2390f241638SMatthias Springer }
2400f241638SMatthias Springer 
2410f241638SMatthias Springer /// Add the pass label to a vector transfer op if its rank is not the target
2420f241638SMatthias Springer /// rank.
2430f241638SMatthias Springer template <typename OpTy>
maybeApplyPassLabel(OpBuilder & b,OpTy newXferOp,unsigned targetRank)2446825bfe2SNicolas Vasilache static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
2452ca887deSMatthias Springer                                 unsigned targetRank) {
2462ca887deSMatthias Springer   if (newXferOp.getVectorType().getRank() > targetRank)
2476825bfe2SNicolas Vasilache     newXferOp->setAttr(kPassLabel, b.getUnitAttr());
2480f241638SMatthias Springer }
2490f241638SMatthias Springer 
250558e7401SMatthias Springer /// Return true if this transfer op operates on a source tensor.
251558e7401SMatthias Springer template <typename OpTy>
isTensorOp(OpTy xferOp)252558e7401SMatthias Springer static bool isTensorOp(OpTy xferOp) {
253558e7401SMatthias Springer   if (xferOp.getShapedType().template isa<RankedTensorType>()) {
254558e7401SMatthias Springer     if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
255558e7401SMatthias Springer       // TransferWriteOps on tensors have a result.
256558e7401SMatthias Springer       assert(xferOp->getNumResults() > 0);
257558e7401SMatthias Springer     }
258558e7401SMatthias Springer     return true;
259558e7401SMatthias Springer   }
260558e7401SMatthias Springer   return false;
261558e7401SMatthias Springer }
262558e7401SMatthias Springer 
263a088bed4SMatthias Springer namespace lowering_n_d {
264a088bed4SMatthias Springer 
265a088bed4SMatthias Springer /// Helper data structure for data and mask buffers.
266a088bed4SMatthias Springer struct BufferAllocs {
267a088bed4SMatthias Springer   Value dataBuffer;
268a088bed4SMatthias Springer   Value maskBuffer;
269a088bed4SMatthias Springer };
270a088bed4SMatthias Springer 
2713c3810e7SNicolas Vasilache // TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
getAutomaticAllocationScope(Operation * op)2723c3810e7SNicolas Vasilache static Operation *getAutomaticAllocationScope(Operation *op) {
2733c3810e7SNicolas Vasilache   Operation *scope =
2743c3810e7SNicolas Vasilache       op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
2753c3810e7SNicolas Vasilache   assert(scope && "Expected op to be inside automatic allocation scope");
2763c3810e7SNicolas Vasilache   return scope;
2773c3810e7SNicolas Vasilache }
2783c3810e7SNicolas Vasilache 
279a088bed4SMatthias Springer /// Allocate temporary buffers for data (vector) and mask (if present).
280a088bed4SMatthias Springer template <typename OpTy>
allocBuffers(OpBuilder & b,OpTy xferOp)2816825bfe2SNicolas Vasilache static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
2826825bfe2SNicolas Vasilache   Location loc = xferOp.getLoc();
283a088bed4SMatthias Springer   OpBuilder::InsertionGuard guard(b);
2843c3810e7SNicolas Vasilache   Operation *scope = getAutomaticAllocationScope(xferOp);
2853c3810e7SNicolas Vasilache   assert(scope->getNumRegions() == 1 &&
2863c3810e7SNicolas Vasilache          "AutomaticAllocationScope with >1 regions");
287a088bed4SMatthias Springer   b.setInsertionPointToStart(&scope->getRegion(0).front());
288a088bed4SMatthias Springer 
289a088bed4SMatthias Springer   BufferAllocs result;
290a088bed4SMatthias Springer   auto bufferType = MemRefType::get({}, xferOp.getVectorType());
2916825bfe2SNicolas Vasilache   result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
292a088bed4SMatthias Springer 
2937c38fd60SJacques Pienaar   if (xferOp.getMask()) {
2947c38fd60SJacques Pienaar     auto maskType = MemRefType::get({}, xferOp.getMask().getType());
2956825bfe2SNicolas Vasilache     auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
296fb7ec1f1SMatthias Springer     b.setInsertionPoint(xferOp);
2977c38fd60SJacques Pienaar     b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
2986825bfe2SNicolas Vasilache     result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
299a088bed4SMatthias Springer   }
300a088bed4SMatthias Springer 
301a088bed4SMatthias Springer   return result;
302a088bed4SMatthias Springer }
303a088bed4SMatthias Springer 
304a088bed4SMatthias Springer /// Given a MemRefType with VectorType element type, unpack one dimension from
305a088bed4SMatthias Springer /// the VectorType into the MemRefType.
306a088bed4SMatthias Springer ///
307a088bed4SMatthias Springer /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
unpackOneDim(MemRefType type)308a088bed4SMatthias Springer static MemRefType unpackOneDim(MemRefType type) {
309a088bed4SMatthias Springer   auto vectorType = type.getElementType().dyn_cast<VectorType>();
310a088bed4SMatthias Springer   auto memrefShape = type.getShape();
311a088bed4SMatthias Springer   SmallVector<int64_t, 8> newMemrefShape;
312a088bed4SMatthias Springer   newMemrefShape.append(memrefShape.begin(), memrefShape.end());
313a088bed4SMatthias Springer   newMemrefShape.push_back(vectorType.getDimSize(0));
314a088bed4SMatthias Springer   return MemRefType::get(newMemrefShape,
315a088bed4SMatthias Springer                          VectorType::get(vectorType.getShape().drop_front(),
316a088bed4SMatthias Springer                                          vectorType.getElementType()));
317a088bed4SMatthias Springer }
318a088bed4SMatthias Springer 
3190f241638SMatthias Springer /// Given a transfer op, find the memref from which the mask is loaded. This
3200f241638SMatthias Springer /// is similar to Strategy<TransferWriteOp>::getBuffer.
3210f241638SMatthias Springer template <typename OpTy>
getMaskBuffer(OpTy xferOp)3220f241638SMatthias Springer static Value getMaskBuffer(OpTy xferOp) {
3237c38fd60SJacques Pienaar   assert(xferOp.getMask() && "Expected that transfer op has mask");
3247c38fd60SJacques Pienaar   auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
3250f241638SMatthias Springer   assert(loadOp && "Expected transfer op mask produced by LoadOp");
3260f241638SMatthias Springer   return loadOp.getMemRef();
3270f241638SMatthias Springer }
3280f241638SMatthias Springer 
3290f241638SMatthias Springer /// Codegen strategy, depending on the operation.
3300f241638SMatthias Springer template <typename OpTy>
3310f241638SMatthias Springer struct Strategy;
3320f241638SMatthias Springer 
3330f241638SMatthias Springer /// Code strategy for vector TransferReadOp.
3344ead2cf7SAlex Zinenko template <>
3350f241638SMatthias Springer struct Strategy<TransferReadOp> {
3360f241638SMatthias Springer   /// Find the StoreOp that is used for writing the current TransferReadOp's
3370f241638SMatthias Springer   /// result to the temporary buffer allocation.
getStoreOp__anon4d9edda10111::lowering_n_d::Strategy3380f241638SMatthias Springer   static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
3390f241638SMatthias Springer     assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
3400f241638SMatthias Springer     auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
3410f241638SMatthias Springer     assert(storeOp && "Expected TransferReadOp result used by StoreOp");
3420f241638SMatthias Springer     return storeOp;
3437c3c5b11SNicolas Vasilache   }
3444ead2cf7SAlex Zinenko 
3450f241638SMatthias Springer   /// Find the temporary buffer allocation. All labeled TransferReadOps are
3460f241638SMatthias Springer   /// used like this, where %buf is either the buffer allocation or a type cast
3470f241638SMatthias Springer   /// of the buffer allocation:
3480f241638SMatthias Springer   /// ```
3490f241638SMatthias Springer   /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
3500f241638SMatthias Springer   /// memref.store %vec, %buf[...] ...
3510f241638SMatthias Springer   /// ```
getBuffer__anon4d9edda10111::lowering_n_d::Strategy3520f241638SMatthias Springer   static Value getBuffer(TransferReadOp xferOp) {
3530f241638SMatthias Springer     return getStoreOp(xferOp).getMemRef();
3541870e787SNicolas Vasilache   }
3550f241638SMatthias Springer 
3560f241638SMatthias Springer   /// Retrieve the indices of the current StoreOp that stores into the buffer.
getBufferIndices__anon4d9edda10111::lowering_n_d::Strategy3570f241638SMatthias Springer   static void getBufferIndices(TransferReadOp xferOp,
3580f241638SMatthias Springer                                SmallVector<Value, 8> &indices) {
3590f241638SMatthias Springer     auto storeOp = getStoreOp(xferOp);
360136d746eSJacques Pienaar     auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
3610f241638SMatthias Springer     indices.append(prevIndices.begin(), prevIndices.end());
3620f241638SMatthias Springer   }
3630f241638SMatthias Springer 
3640f241638SMatthias Springer   /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
3650f241638SMatthias Springer   /// accesses on the to-be-unpacked dimension.
3660f241638SMatthias Springer   ///
3670f241638SMatthias Springer   /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
3680f241638SMatthias Springer   ///    variable `iv`.
3690f241638SMatthias Springer   /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
3700f241638SMatthias Springer   ///
3710f241638SMatthias Springer   /// E.g.:
3720f241638SMatthias Springer   /// ```
3730f241638SMatthias Springer   /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
3740f241638SMatthias Springer   ///     : memref<?x?x?xf32>, vector<4x3xf32>
3750f241638SMatthias Springer   /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
3760f241638SMatthias Springer   /// ```
3770f241638SMatthias Springer   /// Is rewritten to:
3780f241638SMatthias Springer   /// ```
3790f241638SMatthias Springer   /// %casted = vector.type_cast %buf
3800f241638SMatthias Springer   ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
3810f241638SMatthias Springer   /// for %j = 0 to 4 {
3820f241638SMatthias Springer   ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
3830f241638SMatthias Springer   ///       : memref<?x?x?xf32>, vector<3xf32>
3840f241638SMatthias Springer   ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
3850f241638SMatthias Springer   /// }
3860f241638SMatthias Springer   /// ```
3870f241638SMatthias Springer   ///
3880f241638SMatthias Springer   /// Note: The loop and type cast are generated in TransferOpConversion.
3890f241638SMatthias Springer   ///       The original TransferReadOp and store op are deleted in `cleanup`.
3900f241638SMatthias Springer   /// Note: The `mask` operand is set in TransferOpConversion.
rewriteOp__anon4d9edda10111::lowering_n_d::Strategy3916825bfe2SNicolas Vasilache   static TransferReadOp rewriteOp(OpBuilder &b,
3922ca887deSMatthias Springer                                   VectorTransferToSCFOptions options,
393558e7401SMatthias Springer                                   TransferReadOp xferOp, Value buffer, Value iv,
394558e7401SMatthias Springer                                   ValueRange /*loopState*/) {
3950f241638SMatthias Springer     SmallVector<Value, 8> storeIndices;
3960f241638SMatthias Springer     getBufferIndices(xferOp, storeIndices);
3970f241638SMatthias Springer     storeIndices.push_back(iv);
3980f241638SMatthias Springer 
3990f241638SMatthias Springer     SmallVector<Value, 8> xferIndices;
4006825bfe2SNicolas Vasilache     getXferIndices(b, xferOp, iv, xferIndices);
4010f241638SMatthias Springer 
4026825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
4030f241638SMatthias Springer     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
4040f241638SMatthias Springer     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
4057c38fd60SJacques Pienaar     auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
4066825bfe2SNicolas Vasilache     auto newXferOp = b.create<vector::TransferReadOp>(
4077c38fd60SJacques Pienaar         loc, vecType, xferOp.getSource(), xferIndices,
4087c38fd60SJacques Pienaar         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
4097c38fd60SJacques Pienaar         xferOp.getPadding(), Value(), inBoundsAttr);
4100f241638SMatthias Springer 
4116825bfe2SNicolas Vasilache     maybeApplyPassLabel(b, newXferOp, options.targetRank);
4120f241638SMatthias Springer 
4137c38fd60SJacques Pienaar     b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
4146825bfe2SNicolas Vasilache     return newXferOp;
4150f241638SMatthias Springer   }
4160f241638SMatthias Springer 
4170f241638SMatthias Springer   /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
4180f241638SMatthias Springer   /// padding value to the temporary buffer.
handleOutOfBoundsDim__anon4d9edda10111::lowering_n_d::Strategy419558e7401SMatthias Springer   static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
420558e7401SMatthias Springer                                     Value buffer, Value iv,
421558e7401SMatthias Springer                                     ValueRange /*loopState*/) {
4220f241638SMatthias Springer     SmallVector<Value, 8> storeIndices;
4230f241638SMatthias Springer     getBufferIndices(xferOp, storeIndices);
4240f241638SMatthias Springer     storeIndices.push_back(iv);
4250f241638SMatthias Springer 
4266825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
4270f241638SMatthias Springer     auto bufferType = buffer.getType().dyn_cast<ShapedType>();
4280f241638SMatthias Springer     auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
4297c38fd60SJacques Pienaar     auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
4306825bfe2SNicolas Vasilache     b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
431558e7401SMatthias Springer 
432558e7401SMatthias Springer     return Value();
4330f241638SMatthias Springer   }
4340f241638SMatthias Springer 
4350f241638SMatthias Springer   /// Cleanup after rewriting the op.
cleanup__anon4d9edda10111::lowering_n_d::Strategy436558e7401SMatthias Springer   static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
437558e7401SMatthias Springer                       scf::ForOp /*forOp*/) {
4380f241638SMatthias Springer     rewriter.eraseOp(getStoreOp(xferOp));
4390f241638SMatthias Springer     rewriter.eraseOp(xferOp);
4400f241638SMatthias Springer   }
441558e7401SMatthias Springer 
442558e7401SMatthias Springer   /// Return the initial loop state for the generated scf.for loop.
initialLoopState__anon4d9edda10111::lowering_n_d::Strategy443558e7401SMatthias Springer   static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
4444ead2cf7SAlex Zinenko };
4457c3c5b11SNicolas Vasilache 
4460f241638SMatthias Springer /// Codegen strategy for vector TransferWriteOp.
4470f241638SMatthias Springer template <>
4480f241638SMatthias Springer struct Strategy<TransferWriteOp> {
4490f241638SMatthias Springer   /// Find the temporary buffer allocation. All labeled TransferWriteOps are
4500f241638SMatthias Springer   /// used like this, where %buf is either the buffer allocation or a type cast
4510f241638SMatthias Springer   /// of the buffer allocation:
4520f241638SMatthias Springer   /// ```
4530f241638SMatthias Springer   /// %vec = memref.load %buf[...] ...
4540f241638SMatthias Springer   /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
4550f241638SMatthias Springer   /// ```
getBuffer__anon4d9edda10111::lowering_n_d::Strategy4560f241638SMatthias Springer   static Value getBuffer(TransferWriteOp xferOp) {
4577c38fd60SJacques Pienaar     auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
4580f241638SMatthias Springer     assert(loadOp && "Expected transfer op vector produced by LoadOp");
4590f241638SMatthias Springer     return loadOp.getMemRef();
4607c3c5b11SNicolas Vasilache   }
4614ead2cf7SAlex Zinenko 
4620f241638SMatthias Springer   /// Retrieve the indices of the current LoadOp that loads from the buffer.
getBufferIndices__anon4d9edda10111::lowering_n_d::Strategy4630f241638SMatthias Springer   static void getBufferIndices(TransferWriteOp xferOp,
4640f241638SMatthias Springer                                SmallVector<Value, 8> &indices) {
4657c38fd60SJacques Pienaar     auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
466136d746eSJacques Pienaar     auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
4670f241638SMatthias Springer     indices.append(prevIndices.begin(), prevIndices.end());
4680f241638SMatthias Springer   }
4690f241638SMatthias Springer 
4700f241638SMatthias Springer   /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
4710f241638SMatthias Springer   /// accesses on the to-be-unpacked dimension.
4720f241638SMatthias Springer   ///
4730f241638SMatthias Springer   /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
4740f241638SMatthias Springer   ///    using the loop iteration variable `iv`.
4750f241638SMatthias Springer   /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
4760f241638SMatthias Springer   ///    to memory.
4770f241638SMatthias Springer   ///
4780f241638SMatthias Springer   /// Note: For more details, see comments on Strategy<TransferReadOp>.
rewriteOp__anon4d9edda10111::lowering_n_d::Strategy4796825bfe2SNicolas Vasilache   static TransferWriteOp rewriteOp(OpBuilder &b,
4802ca887deSMatthias Springer                                    VectorTransferToSCFOptions options,
4812ca887deSMatthias Springer                                    TransferWriteOp xferOp, Value buffer,
482558e7401SMatthias Springer                                    Value iv, ValueRange loopState) {
4830f241638SMatthias Springer     SmallVector<Value, 8> loadIndices;
4840f241638SMatthias Springer     getBufferIndices(xferOp, loadIndices);
4850f241638SMatthias Springer     loadIndices.push_back(iv);
4860f241638SMatthias Springer 
4870f241638SMatthias Springer     SmallVector<Value, 8> xferIndices;
4886825bfe2SNicolas Vasilache     getXferIndices(b, xferOp, iv, xferIndices);
4890f241638SMatthias Springer 
4906825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
4916825bfe2SNicolas Vasilache     auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
4927c38fd60SJacques Pienaar     auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
4937c38fd60SJacques Pienaar     auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
494558e7401SMatthias Springer     Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
4956825bfe2SNicolas Vasilache     auto newXferOp = b.create<vector::TransferWriteOp>(
496558e7401SMatthias Springer         loc, type, vec, source, xferIndices,
4976825bfe2SNicolas Vasilache         AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
4980f241638SMatthias Springer         inBoundsAttr);
4990f241638SMatthias Springer 
5006825bfe2SNicolas Vasilache     maybeApplyPassLabel(b, newXferOp, options.targetRank);
5010f241638SMatthias Springer 
5026825bfe2SNicolas Vasilache     return newXferOp;
5030f241638SMatthias Springer   }
5040f241638SMatthias Springer 
5050f241638SMatthias Springer   /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
handleOutOfBoundsDim__anon4d9edda10111::lowering_n_d::Strategy506558e7401SMatthias Springer   static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
507558e7401SMatthias Springer                                     Value buffer, Value iv,
508558e7401SMatthias Springer                                     ValueRange loopState) {
509558e7401SMatthias Springer     return isTensorOp(xferOp) ? loopState[0] : Value();
510558e7401SMatthias Springer   }
5110f241638SMatthias Springer 
5120f241638SMatthias Springer   /// Cleanup after rewriting the op.
cleanup__anon4d9edda10111::lowering_n_d::Strategy513558e7401SMatthias Springer   static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
514558e7401SMatthias Springer                       scf::ForOp forOp) {
515558e7401SMatthias Springer     if (isTensorOp(xferOp)) {
516558e7401SMatthias Springer       assert(forOp->getNumResults() == 1 && "Expected one for loop result");
517558e7401SMatthias Springer       rewriter.replaceOp(xferOp, forOp->getResult(0));
518558e7401SMatthias Springer     } else {
5190f241638SMatthias Springer       rewriter.eraseOp(xferOp);
5200f241638SMatthias Springer     }
521558e7401SMatthias Springer   }
522558e7401SMatthias Springer 
523558e7401SMatthias Springer   /// Return the initial loop state for the generated scf.for loop.
initialLoopState__anon4d9edda10111::lowering_n_d::Strategy524558e7401SMatthias Springer   static Value initialLoopState(TransferWriteOp xferOp) {
5257c38fd60SJacques Pienaar     return isTensorOp(xferOp) ? xferOp.getSource() : Value();
526558e7401SMatthias Springer   }
5270f241638SMatthias Springer };
5280f241638SMatthias Springer 
5290f241638SMatthias Springer template <typename OpTy>
checkPrepareXferOp(OpTy xferOp,VectorTransferToSCFOptions options)530fb7ec1f1SMatthias Springer LogicalResult checkPrepareXferOp(OpTy xferOp,
531fb7ec1f1SMatthias Springer                                  VectorTransferToSCFOptions options) {
5320f241638SMatthias Springer   if (xferOp->hasAttr(kPassLabel))
5330f241638SMatthias Springer     return failure();
534fb7ec1f1SMatthias Springer   if (xferOp.getVectorType().getRank() <= options.targetRank)
5350f241638SMatthias Springer     return failure();
536558e7401SMatthias Springer   if (isTensorOp(xferOp) && !options.lowerTensors)
5378fb48979SMatthias Springer     return failure();
538f718a53dSMatthias Springer   // Transfer ops that modify the element type are not supported atm.
539f718a53dSMatthias Springer   if (xferOp.getVectorType().getElementType() !=
540f718a53dSMatthias Springer       xferOp.getShapedType().getElementType())
541f718a53dSMatthias Springer     return failure();
5420f241638SMatthias Springer   return success();
5430f241638SMatthias Springer }
5440f241638SMatthias Springer 
5450f241638SMatthias Springer /// Prepare a TransferReadOp for progressive lowering.
5460f241638SMatthias Springer ///
5470f241638SMatthias Springer /// 1. Allocate a temporary buffer.
5480f241638SMatthias Springer /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
5490f241638SMatthias Springer /// 3. Store the result of the TransferReadOp into the temporary buffer.
5500f241638SMatthias Springer /// 4. Load the result from the temporary buffer and replace all uses of the
5510f241638SMatthias Springer ///    original TransferReadOp with this load.
5520f241638SMatthias Springer ///
5530f241638SMatthias Springer /// E.g.:
5540f241638SMatthias Springer /// ```
5550f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
5560f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
5570f241638SMatthias Springer /// ```
5580f241638SMatthias Springer /// is rewritten to:
5590f241638SMatthias Springer /// ```
5600f241638SMatthias Springer /// %0 = memref.alloca() : memref<vector<5x4xf32>>
5610f241638SMatthias Springer /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
5620f241638SMatthias Springer ///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
5630f241638SMatthias Springer /// memref.store %1, %0[] : memref<vector<5x4xf32>>
5640f241638SMatthias Springer /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
5650f241638SMatthias Springer /// ```
5660f241638SMatthias Springer ///
5670f241638SMatthias Springer /// Note: A second temporary buffer may be allocated for the `mask` operand.
5682ca887deSMatthias Springer struct PrepareTransferReadConversion
5692ca887deSMatthias Springer     : public VectorToSCFPattern<TransferReadOp> {
5702ca887deSMatthias Springer   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
5710f241638SMatthias Springer 
matchAndRewrite__anon4d9edda10111::lowering_n_d::PrepareTransferReadConversion5720f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferReadOp xferOp,
5730f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
574fb7ec1f1SMatthias Springer     if (checkPrepareXferOp(xferOp, options).failed())
5750f241638SMatthias Springer       return failure();
5760f241638SMatthias Springer 
5776825bfe2SNicolas Vasilache     auto buffers = allocBuffers(rewriter, xferOp);
5780f241638SMatthias Springer     auto *newXfer = rewriter.clone(*xferOp.getOperation());
5790f241638SMatthias Springer     newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
5807c38fd60SJacques Pienaar     if (xferOp.getMask()) {
5817c38fd60SJacques Pienaar       dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
5820f241638SMatthias Springer           buffers.maskBuffer);
5830f241638SMatthias Springer     }
5840f241638SMatthias Springer 
5856825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
5866825bfe2SNicolas Vasilache     rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
5876825bfe2SNicolas Vasilache                                      buffers.dataBuffer);
5880f241638SMatthias Springer     rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
5894ead2cf7SAlex Zinenko 
5904ead2cf7SAlex Zinenko     return success();
5914ead2cf7SAlex Zinenko   }
5920f241638SMatthias Springer };
5930f241638SMatthias Springer 
5940f241638SMatthias Springer /// Prepare a TransferWriteOp for progressive lowering.
5950f241638SMatthias Springer ///
5960f241638SMatthias Springer /// 1. Allocate a temporary buffer.
5970f241638SMatthias Springer /// 2. Store the vector into the buffer.
5980f241638SMatthias Springer /// 3. Load the vector from the buffer again.
5990f241638SMatthias Springer /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
6000f241638SMatthias Springer ///    marking it eligible for progressive lowering via TransferOpConversion.
6010f241638SMatthias Springer ///
6020f241638SMatthias Springer /// E.g.:
6030f241638SMatthias Springer /// ```
6040f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c]
6050f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
6060f241638SMatthias Springer /// ```
6070f241638SMatthias Springer /// is rewritten to:
6080f241638SMatthias Springer /// ```
6090f241638SMatthias Springer /// %0 = memref.alloca() : memref<vector<5x4xf32>>
6100f241638SMatthias Springer /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
6110f241638SMatthias Springer /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
6120f241638SMatthias Springer /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
6130f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
6140f241638SMatthias Springer /// ```
6150f241638SMatthias Springer ///
6160f241638SMatthias Springer /// Note: A second temporary buffer may be allocated for the `mask` operand.
6170f241638SMatthias Springer struct PrepareTransferWriteConversion
6182ca887deSMatthias Springer     : public VectorToSCFPattern<TransferWriteOp> {
6192ca887deSMatthias Springer   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
6200f241638SMatthias Springer 
matchAndRewrite__anon4d9edda10111::lowering_n_d::PrepareTransferWriteConversion6210f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
6220f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
623fb7ec1f1SMatthias Springer     if (checkPrepareXferOp(xferOp, options).failed())
6240f241638SMatthias Springer       return failure();
6250f241638SMatthias Springer 
6266825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
6276825bfe2SNicolas Vasilache     auto buffers = allocBuffers(rewriter, xferOp);
6287c38fd60SJacques Pienaar     rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
6297c38fd60SJacques Pienaar                                      buffers.dataBuffer);
6306825bfe2SNicolas Vasilache     auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
6310f241638SMatthias Springer     rewriter.updateRootInPlace(xferOp, [&]() {
6327c38fd60SJacques Pienaar       xferOp.getVectorMutable().assign(loadedVec);
6330f241638SMatthias Springer       xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
6340f241638SMatthias Springer     });
6350f241638SMatthias Springer 
6367c38fd60SJacques Pienaar     if (xferOp.getMask()) {
6377c38fd60SJacques Pienaar       rewriter.updateRootInPlace(xferOp, [&]() {
6387c38fd60SJacques Pienaar         xferOp.getMaskMutable().assign(buffers.maskBuffer);
6397c38fd60SJacques Pienaar       });
6400f241638SMatthias Springer     }
6410f241638SMatthias Springer 
6420f241638SMatthias Springer     return success();
6430f241638SMatthias Springer   }
6440f241638SMatthias Springer };
6450f241638SMatthias Springer 
6460f241638SMatthias Springer /// Progressive lowering of vector transfer ops: Unpack one dimension.
6470f241638SMatthias Springer ///
6480f241638SMatthias Springer /// 1. Unpack one dimension from the current buffer type and cast the buffer
6490f241638SMatthias Springer ///    to that new type. E.g.:
6500f241638SMatthias Springer ///    ```
6510f241638SMatthias Springer ///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
6520f241638SMatthias Springer ///    vector.transfer_write %vec ...
6530f241638SMatthias Springer ///    ```
6540f241638SMatthias Springer ///    The following cast is generated:
6550f241638SMatthias Springer ///    ```
6560f241638SMatthias Springer ///    %casted = vector.type_cast %0
6570f241638SMatthias Springer ///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
6580f241638SMatthias Springer ///    ```
6590f241638SMatthias Springer /// 2. Generate a for loop and rewrite the transfer op according to the
6600f241638SMatthias Springer ///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
6610f241638SMatthias Springer ///    out-of-bounds, generate an if-check and handle both cases separately.
6620f241638SMatthias Springer /// 3. Clean up according to the corresponding Strategy<OpTy>.
663558e7401SMatthias Springer ///
664558e7401SMatthias Springer /// Note: If the transfer op is a TransferWriteOp and operates on a tensor
665558e7401SMatthias Springer /// source (as opposed to a memref source), then each iteration of the generated
666558e7401SMatthias Springer /// scf.for loop yields the new tensor value. E.g.:
667558e7401SMatthias Springer /// ```
668558e7401SMatthias Springer /// %result = scf.for i = 0 to 5 {
669558e7401SMatthias Springer ///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
670558e7401SMatthias Springer ///   %1 = vector.transfer_write %0, %source[...]
671558e7401SMatthias Springer ///       : vector<4x3xf32>, tensor<5x4x3xf32>
672558e7401SMatthias Springer ///   scf.yield %1 : tensor<5x4x3xf32>
673558e7401SMatthias Springer /// }
674558e7401SMatthias Springer /// ```
6750f241638SMatthias Springer template <typename OpTy>
6762ca887deSMatthias Springer struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
6772ca887deSMatthias Springer   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
6780f241638SMatthias Springer 
initialize__anon4d9edda10111::lowering_n_d::TransferOpConversion679700b64dcSMatthias Springer   void initialize() {
680700b64dcSMatthias Springer     // This pattern recursively unpacks one dimension at a time. The recursion
681700b64dcSMatthias Springer     // bounded as the rank is strictly decreasing.
682700b64dcSMatthias Springer     this->setHasBoundedRewriteRecursion();
683700b64dcSMatthias Springer   }
684700b64dcSMatthias Springer 
matchAndRewrite__anon4d9edda10111::lowering_n_d::TransferOpConversion6850f241638SMatthias Springer   LogicalResult matchAndRewrite(OpTy xferOp,
6860f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
6870f241638SMatthias Springer     if (!xferOp->hasAttr(kPassLabel))
6880f241638SMatthias Springer       return failure();
6890f241638SMatthias Springer 
6900f241638SMatthias Springer     // Find and cast data buffer. How the buffer can be found depends on OpTy.
6916825bfe2SNicolas Vasilache     ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
6920f241638SMatthias Springer     auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
6930f241638SMatthias Springer     auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
6940f241638SMatthias Springer     auto castedDataType = unpackOneDim(dataBufferType);
6956825bfe2SNicolas Vasilache     auto castedDataBuffer =
6966825bfe2SNicolas Vasilache         locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
6970f241638SMatthias Springer 
6980f241638SMatthias Springer     // If the xferOp has a mask: Find and cast mask buffer.
6990f241638SMatthias Springer     Value castedMaskBuffer;
7007c38fd60SJacques Pienaar     if (xferOp.getMask()) {
7010f241638SMatthias Springer       auto maskBuffer = getMaskBuffer(xferOp);
7020f241638SMatthias Springer       auto maskBufferType =
7030f241638SMatthias Springer           maskBuffer.getType().template dyn_cast<MemRefType>();
7040f241638SMatthias Springer       if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
7050f241638SMatthias Springer         // Do not unpack a dimension of the mask, if:
7060f241638SMatthias Springer         // * To-be-unpacked transfer op dimension is a broadcast.
7070f241638SMatthias Springer         // * Mask is 1D, i.e., the mask cannot be further unpacked.
7080f241638SMatthias Springer         //   (That means that all remaining dimensions of the transfer op must
7090f241638SMatthias Springer         //   be broadcasted.)
7100f241638SMatthias Springer         castedMaskBuffer = maskBuffer;
7110f241638SMatthias Springer       } else {
7120f241638SMatthias Springer         auto castedMaskType = unpackOneDim(maskBufferType);
7136825bfe2SNicolas Vasilache         castedMaskBuffer =
7146825bfe2SNicolas Vasilache             locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
7150f241638SMatthias Springer       }
7160f241638SMatthias Springer     }
7170f241638SMatthias Springer 
7180f241638SMatthias Springer     // Loop bounds and step.
719a54f4eaeSMogball     auto lb = locB.create<arith::ConstantIndexOp>(0);
720a54f4eaeSMogball     auto ub = locB.create<arith::ConstantIndexOp>(
7216825bfe2SNicolas Vasilache         castedDataType.getDimSize(castedDataType.getRank() - 1));
722a54f4eaeSMogball     auto step = locB.create<arith::ConstantIndexOp>(1);
723558e7401SMatthias Springer     // TransferWriteOps that operate on tensors return the modified tensor and
724558e7401SMatthias Springer     // require a loop state.
725558e7401SMatthias Springer     auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
7260f241638SMatthias Springer 
7270f241638SMatthias Springer     // Generate for loop.
728558e7401SMatthias Springer     auto result = locB.create<scf::ForOp>(
729558e7401SMatthias Springer         lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
730558e7401SMatthias Springer         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
731558e7401SMatthias Springer           Type stateType = loopState.empty() ? Type() : loopState[0].getType();
732558e7401SMatthias Springer 
733558e7401SMatthias Springer           auto result = generateInBoundsCheck(
7346825bfe2SNicolas Vasilache               b, xferOp, iv, unpackedDim(xferOp),
735558e7401SMatthias Springer               stateType ? TypeRange(stateType) : TypeRange(),
7360f241638SMatthias Springer               /*inBoundsCase=*/
7376825bfe2SNicolas Vasilache               [&](OpBuilder &b, Location loc) {
7380f241638SMatthias Springer                 // Create new transfer op.
7392ca887deSMatthias Springer                 OpTy newXfer = Strategy<OpTy>::rewriteOp(
740558e7401SMatthias Springer                     b, this->options, xferOp, castedDataBuffer, iv, loopState);
7410f241638SMatthias Springer 
7420f241638SMatthias Springer                 // If old transfer op has a mask: Set mask on new transfer op.
7430f241638SMatthias Springer                 // Special case: If the mask of the old transfer op is 1D and
7440f241638SMatthias Springer                 // the
7450f241638SMatthias Springer                 //               unpacked dim is not a broadcast, no mask is
7460f241638SMatthias Springer                 //               needed on the new transfer op.
7477c38fd60SJacques Pienaar                 if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
7480f241638SMatthias Springer                                          xferOp.getMaskType().getRank() > 1)) {
7490f241638SMatthias Springer                   OpBuilder::InsertionGuard guard(b);
7500f241638SMatthias Springer                   b.setInsertionPoint(newXfer); // Insert load before newXfer.
7510f241638SMatthias Springer 
7520f241638SMatthias Springer                   SmallVector<Value, 8> loadIndices;
7530f241638SMatthias Springer                   Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
7540f241638SMatthias Springer                   // In case of broadcast: Use same indices to load from memref
7550f241638SMatthias Springer                   // as before.
7560f241638SMatthias Springer                   if (!xferOp.isBroadcastDim(0))
7570f241638SMatthias Springer                     loadIndices.push_back(iv);
7580f241638SMatthias Springer 
7596825bfe2SNicolas Vasilache                   auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
7606825bfe2SNicolas Vasilache                                                        loadIndices);
7617c38fd60SJacques Pienaar                   rewriter.updateRootInPlace(newXfer, [&]() {
7627c38fd60SJacques Pienaar                     newXfer.getMaskMutable().assign(mask);
7637c38fd60SJacques Pienaar                   });
7640f241638SMatthias Springer                 }
765558e7401SMatthias Springer 
766558e7401SMatthias Springer                 return loopState.empty() ? Value() : newXfer->getResult(0);
7670f241638SMatthias Springer               },
7680f241638SMatthias Springer               /*outOfBoundsCase=*/
7690f241638SMatthias Springer               [&](OpBuilder &b, Location /*loc*/) {
770558e7401SMatthias Springer                 return Strategy<OpTy>::handleOutOfBoundsDim(
771558e7401SMatthias Springer                     b, xferOp, castedDataBuffer, iv, loopState);
7720f241638SMatthias Springer               });
7730f241638SMatthias Springer 
774558e7401SMatthias Springer           maybeYieldValue(b, loc, !loopState.empty(), result);
775558e7401SMatthias Springer         });
776558e7401SMatthias Springer 
777558e7401SMatthias Springer     Strategy<OpTy>::cleanup(rewriter, xferOp, result);
7780f241638SMatthias Springer     return success();
7790f241638SMatthias Springer   }
7800f241638SMatthias Springer };
7810f241638SMatthias Springer 
782a088bed4SMatthias Springer } // namespace lowering_n_d
783a088bed4SMatthias Springer 
784a088bed4SMatthias Springer namespace lowering_n_d_unrolled {
785a088bed4SMatthias Springer 
7860f241638SMatthias Springer /// If the original transfer op has a mask, compute the mask of the new transfer
7870f241638SMatthias Springer /// op (for the current iteration `i`) and assign it.
7880f241638SMatthias Springer template <typename OpTy>
maybeAssignMask(OpBuilder & b,OpTy xferOp,OpTy newXferOp,int64_t i)7896825bfe2SNicolas Vasilache static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
7900f241638SMatthias Springer                             int64_t i) {
7917c38fd60SJacques Pienaar   if (!xferOp.getMask())
7920f241638SMatthias Springer     return;
7930f241638SMatthias Springer 
7940f241638SMatthias Springer   if (xferOp.isBroadcastDim(0)) {
7950f241638SMatthias Springer     // To-be-unpacked dimension is a broadcast, which does not have a
7960f241638SMatthias Springer     // corresponding mask dimension. Mask attribute remains unchanged.
7977c38fd60SJacques Pienaar     newXferOp.getMaskMutable().assign(xferOp.getMask());
7980f241638SMatthias Springer     return;
7990f241638SMatthias Springer   }
8000f241638SMatthias Springer 
8010f241638SMatthias Springer   if (xferOp.getMaskType().getRank() > 1) {
8020f241638SMatthias Springer     // Unpack one dimension of the mask.
8036825bfe2SNicolas Vasilache     OpBuilder::InsertionGuard guard(b);
8046825bfe2SNicolas Vasilache     b.setInsertionPoint(newXferOp); // Insert load before newXfer.
8050f241638SMatthias Springer 
8060f241638SMatthias Springer     llvm::SmallVector<int64_t, 1> indices({i});
8076825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
8087c38fd60SJacques Pienaar     auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
8097c38fd60SJacques Pienaar     newXferOp.getMaskMutable().assign(newMask);
8100f241638SMatthias Springer   }
8110f241638SMatthias Springer 
8120f241638SMatthias Springer   // If we end up here: The mask of the old transfer op is 1D and the unpacked
8130f241638SMatthias Springer   // dim is not a broadcast, so no mask is needed on the new transfer op.
8140f241638SMatthias Springer   // `generateInBoundsCheck` will have evaluated the mask already.
8150f241638SMatthias Springer }
8160f241638SMatthias Springer 
8170f241638SMatthias Springer /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
8180f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
8190f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
8200f241638SMatthias Springer ///
/// E.g.:
/// ```
8240f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
8250f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<5x4xf32>
8260f241638SMatthias Springer /// ```
8270f241638SMatthias Springer /// is rewritten to IR such as (simplified):
8280f241638SMatthias Springer /// ```
8290f241638SMatthias Springer /// %v_init = splat %padding : vector<5x4xf32>
8300f241638SMatthias Springer /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
8310f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
8320f241638SMatthias Springer /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
8330f241638SMatthias Springer /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
8340f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
8350f241638SMatthias Springer /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
8360f241638SMatthias Springer /// ...
8370f241638SMatthias Springer /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
8380f241638SMatthias Springer ///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
8400f241638SMatthias Springer /// ```
8410f241638SMatthias Springer ///
8420f241638SMatthias Springer /// Note: As an optimization, if the result of the original TransferReadOp
8430f241638SMatthias Springer /// was directly inserted into another vector, no new %v_init vector is created.
8440f241638SMatthias Springer /// Instead, the new TransferReadOp results are inserted into that vector.
8452ca887deSMatthias Springer struct UnrollTransferReadConversion
8462ca887deSMatthias Springer     : public VectorToSCFPattern<TransferReadOp> {
8472ca887deSMatthias Springer   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
8480f241638SMatthias Springer 
initialize__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferReadConversion849700b64dcSMatthias Springer   void initialize() {
850700b64dcSMatthias Springer     // This pattern recursively unpacks one dimension at a time. The recursion
851700b64dcSMatthias Springer     // bounded as the rank is strictly decreasing.
852700b64dcSMatthias Springer     setHasBoundedRewriteRecursion();
853700b64dcSMatthias Springer   }
854700b64dcSMatthias Springer 
8550f241638SMatthias Springer   /// Return the vector into which the newly created TransferReadOp results
8560f241638SMatthias Springer   /// are inserted.
getResultVector__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferReadConversion8570f241638SMatthias Springer   Value getResultVector(TransferReadOp xferOp,
8580f241638SMatthias Springer                         PatternRewriter &rewriter) const {
8590f241638SMatthias Springer     if (auto insertOp = getInsertOp(xferOp))
8607c38fd60SJacques Pienaar       return insertOp.getDest();
8616825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
8626a8ba318SRiver Riddle     return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
8637c38fd60SJacques Pienaar                                             xferOp.getPadding());
8640f241638SMatthias Springer   }
8650f241638SMatthias Springer 
8660f241638SMatthias Springer   /// If the result of the TransferReadOp has exactly one user, which is a
8670f241638SMatthias Springer   /// vector::InsertOp, return that operation.
getInsertOp__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferReadConversion8680f241638SMatthias Springer   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
8690f241638SMatthias Springer     if (xferOp->hasOneUse()) {
8700f241638SMatthias Springer       Operation *xferOpUser = *xferOp->getUsers().begin();
8710f241638SMatthias Springer       if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
8720f241638SMatthias Springer         return insertOp;
8730f241638SMatthias Springer     }
8740f241638SMatthias Springer 
8750f241638SMatthias Springer     return vector::InsertOp();
8760f241638SMatthias Springer   }
8770f241638SMatthias Springer 
8780f241638SMatthias Springer   /// If the result of the TransferReadOp has exactly one user, which is a
8790f241638SMatthias Springer   /// vector::InsertOp, return that operation's indices.
getInsertionIndices__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferReadConversion8800f241638SMatthias Springer   void getInsertionIndices(TransferReadOp xferOp,
8810f241638SMatthias Springer                            SmallVector<int64_t, 8> &indices) const {
8820f241638SMatthias Springer     if (auto insertOp = getInsertOp(xferOp)) {
883*c730f9a1SKazu Hirata       for (Attribute attr : insertOp.getPosition())
8840f241638SMatthias Springer         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
8850f241638SMatthias Springer     }
8860f241638SMatthias Springer   }
8870f241638SMatthias Springer 
8880f241638SMatthias Springer   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
8890f241638SMatthias Springer   /// accesses, and broadcasts and transposes in permutation maps.
matchAndRewrite__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferReadConversion8900f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferReadOp xferOp,
8910f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
8922ca887deSMatthias Springer     if (xferOp.getVectorType().getRank() <= options.targetRank)
8930f241638SMatthias Springer       return failure();
894bd20756dSMatthias Springer     if (isTensorOp(xferOp) && !options.lowerTensors)
8958fb48979SMatthias Springer       return failure();
896f718a53dSMatthias Springer     // Transfer ops that modify the element type are not supported atm.
897f718a53dSMatthias Springer     if (xferOp.getVectorType().getElementType() !=
898f718a53dSMatthias Springer         xferOp.getShapedType().getElementType())
899f718a53dSMatthias Springer       return failure();
9000f241638SMatthias Springer 
9010f241638SMatthias Springer     auto insertOp = getInsertOp(xferOp);
9020f241638SMatthias Springer     auto vec = getResultVector(xferOp, rewriter);
9030f241638SMatthias Springer     auto vecType = vec.getType().dyn_cast<VectorType>();
9040f241638SMatthias Springer     auto xferVecType = xferOp.getVectorType();
9050f241638SMatthias Springer     auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
9060f241638SMatthias Springer                                           xferVecType.getElementType());
9070f241638SMatthias Springer     int64_t dimSize = xferVecType.getShape()[0];
9080f241638SMatthias Springer 
9090f241638SMatthias Springer     // Generate fully unrolled loop of transfer ops.
9106825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
9110f241638SMatthias Springer     for (int64_t i = 0; i < dimSize; ++i) {
912a54f4eaeSMogball       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
9130f241638SMatthias Springer 
9140f241638SMatthias Springer       vec = generateInBoundsCheck(
9156825bfe2SNicolas Vasilache           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
9160f241638SMatthias Springer           /*inBoundsCase=*/
9170f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
9180f241638SMatthias Springer             // Indices for the new transfer op.
9190f241638SMatthias Springer             SmallVector<Value, 8> xferIndices;
9206825bfe2SNicolas Vasilache             getXferIndices(b, xferOp, iv, xferIndices);
9210f241638SMatthias Springer 
9220f241638SMatthias Springer             // Indices for the new vector.insert op.
9230f241638SMatthias Springer             SmallVector<int64_t, 8> insertionIndices;
9240f241638SMatthias Springer             getInsertionIndices(xferOp, insertionIndices);
9250f241638SMatthias Springer             insertionIndices.push_back(i);
9260f241638SMatthias Springer 
9277c38fd60SJacques Pienaar             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
9286825bfe2SNicolas Vasilache             auto newXferOp = b.create<vector::TransferReadOp>(
9297c38fd60SJacques Pienaar                 loc, newXferVecType, xferOp.getSource(), xferIndices,
9306825bfe2SNicolas Vasilache                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
9317c38fd60SJacques Pienaar                 xferOp.getPadding(), Value(), inBoundsAttr);
9320f241638SMatthias Springer             maybeAssignMask(b, xferOp, newXferOp, i);
9336825bfe2SNicolas Vasilache             return b.create<vector::InsertOp>(loc, newXferOp, vec,
9346825bfe2SNicolas Vasilache                                               insertionIndices);
9350f241638SMatthias Springer           },
9360f241638SMatthias Springer           /*outOfBoundsCase=*/
9370f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
9380f241638SMatthias Springer             // Loop through original (unmodified) vector.
9390f241638SMatthias Springer             return vec;
9400f241638SMatthias Springer           });
9410f241638SMatthias Springer     }
9420f241638SMatthias Springer 
9430f241638SMatthias Springer     if (insertOp) {
9440f241638SMatthias Springer       // Rewrite single user of the old TransferReadOp, which was an InsertOp.
9450f241638SMatthias Springer       rewriter.replaceOp(insertOp, vec);
9460f241638SMatthias Springer       rewriter.eraseOp(xferOp);
9470f241638SMatthias Springer     } else {
9480f241638SMatthias Springer       rewriter.replaceOp(xferOp, vec);
9490f241638SMatthias Springer     }
9500f241638SMatthias Springer 
9510f241638SMatthias Springer     return success();
9520f241638SMatthias Springer   }
9530f241638SMatthias Springer };
9540f241638SMatthias Springer 
9550f241638SMatthias Springer /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
9560f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
9570f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
9580f241638SMatthias Springer ///
9590f241638SMatthias Springer /// ```
9600f241638SMatthias Springer /// E.g.:
9610f241638SMatthias Springer /// ```
9620f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c]
9630f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
9640f241638SMatthias Springer /// ```
9650f241638SMatthias Springer /// is rewritten to IR such as (simplified):
9660f241638SMatthias Springer /// ```
9670f241638SMatthias Springer /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
9680f241638SMatthias Springer /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
9690f241638SMatthias Springer /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
9700f241638SMatthias Springer /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
9710f241638SMatthias Springer /// ...
9720f241638SMatthias Springer /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
9730f241638SMatthias Springer /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
9740f241638SMatthias Springer /// ```
9750f241638SMatthias Springer ///
9760f241638SMatthias Springer /// Note: As an optimization, if the vector of the original TransferWriteOp
9770f241638SMatthias Springer /// was directly extracted from another vector via an ExtractOp `a`, extract
9780f241638SMatthias Springer /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
9790f241638SMatthias Springer /// doing so, `a` may become dead, and the number of ExtractOps generated during
9800f241638SMatthias Springer /// recursive application of this pattern will be minimal.
9810f241638SMatthias Springer struct UnrollTransferWriteConversion
9822ca887deSMatthias Springer     : public VectorToSCFPattern<TransferWriteOp> {
9832ca887deSMatthias Springer   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
9840f241638SMatthias Springer 
initialize__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferWriteConversion985700b64dcSMatthias Springer   void initialize() {
986700b64dcSMatthias Springer     // This pattern recursively unpacks one dimension at a time. The recursion
987700b64dcSMatthias Springer     // bounded as the rank is strictly decreasing.
988700b64dcSMatthias Springer     setHasBoundedRewriteRecursion();
989700b64dcSMatthias Springer   }
990700b64dcSMatthias Springer 
9910f241638SMatthias Springer   /// Return the vector from which newly generated ExtracOps will extract.
getDataVector__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferWriteConversion9920f241638SMatthias Springer   Value getDataVector(TransferWriteOp xferOp) const {
9930f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp))
9947c38fd60SJacques Pienaar       return extractOp.getVector();
9957c38fd60SJacques Pienaar     return xferOp.getVector();
9960f241638SMatthias Springer   }
9970f241638SMatthias Springer 
9980f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
getExtractOp__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferWriteConversion9990f241638SMatthias Springer   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
10007c38fd60SJacques Pienaar     if (auto *op = xferOp.getVector().getDefiningOp())
10010f241638SMatthias Springer       return dyn_cast<vector::ExtractOp>(op);
10020f241638SMatthias Springer     return vector::ExtractOp();
10030f241638SMatthias Springer   }
10040f241638SMatthias Springer 
10050f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return its
10060f241638SMatthias Springer   /// indices.
getExtractionIndices__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferWriteConversion10070f241638SMatthias Springer   void getExtractionIndices(TransferWriteOp xferOp,
10080f241638SMatthias Springer                             SmallVector<int64_t, 8> &indices) const {
10090f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp)) {
1010*c730f9a1SKazu Hirata       for (Attribute attr : extractOp.getPosition())
10110f241638SMatthias Springer         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
10120f241638SMatthias Springer     }
10130f241638SMatthias Springer   }
10140f241638SMatthias Springer 
10150f241638SMatthias Springer   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
10160f241638SMatthias Springer   /// accesses, and broadcasts and transposes in permutation maps.
matchAndRewrite__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferWriteConversion10170f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
10180f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
10192ca887deSMatthias Springer     if (xferOp.getVectorType().getRank() <= options.targetRank)
10200f241638SMatthias Springer       return failure();
1021bd20756dSMatthias Springer     if (isTensorOp(xferOp) && !options.lowerTensors)
10228fb48979SMatthias Springer       return failure();
1023f718a53dSMatthias Springer     // Transfer ops that modify the element type are not supported atm.
1024f718a53dSMatthias Springer     if (xferOp.getVectorType().getElementType() !=
1025f718a53dSMatthias Springer         xferOp.getShapedType().getElementType())
1026f718a53dSMatthias Springer       return failure();
10270f241638SMatthias Springer 
10280f241638SMatthias Springer     auto vec = getDataVector(xferOp);
10290f241638SMatthias Springer     auto xferVecType = xferOp.getVectorType();
10300f241638SMatthias Springer     int64_t dimSize = xferVecType.getShape()[0];
10317c38fd60SJacques Pienaar     auto source = xferOp.getSource(); // memref or tensor to be written to.
1032bd20756dSMatthias Springer     auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
10330f241638SMatthias Springer 
10340f241638SMatthias Springer     // Generate fully unrolled loop of transfer ops.
10356825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
10360f241638SMatthias Springer     for (int64_t i = 0; i < dimSize; ++i) {
1037a54f4eaeSMogball       Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
10380f241638SMatthias Springer 
1039bd20756dSMatthias Springer       auto updatedSource = generateInBoundsCheck(
10406825bfe2SNicolas Vasilache           rewriter, xferOp, iv, unpackedDim(xferOp),
1041bd20756dSMatthias Springer           isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1042bd20756dSMatthias Springer           /*inBoundsCase=*/
1043bd20756dSMatthias Springer           [&](OpBuilder &b, Location loc) {
10440f241638SMatthias Springer             // Indices for the new transfer op.
10450f241638SMatthias Springer             SmallVector<Value, 8> xferIndices;
10466825bfe2SNicolas Vasilache             getXferIndices(b, xferOp, iv, xferIndices);
10470f241638SMatthias Springer 
10480f241638SMatthias Springer             // Indices for the new vector.extract op.
10490f241638SMatthias Springer             SmallVector<int64_t, 8> extractionIndices;
10500f241638SMatthias Springer             getExtractionIndices(xferOp, extractionIndices);
10510f241638SMatthias Springer             extractionIndices.push_back(i);
10520f241638SMatthias Springer 
10536825bfe2SNicolas Vasilache             auto extracted =
10546825bfe2SNicolas Vasilache                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
10557c38fd60SJacques Pienaar             auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
10566825bfe2SNicolas Vasilache             auto newXferOp = b.create<vector::TransferWriteOp>(
1057bd20756dSMatthias Springer                 loc, sourceType, extracted, source, xferIndices,
10586825bfe2SNicolas Vasilache                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
10596825bfe2SNicolas Vasilache                 inBoundsAttr);
10600f241638SMatthias Springer 
10610f241638SMatthias Springer             maybeAssignMask(b, xferOp, newXferOp, i);
1062bd20756dSMatthias Springer 
1063bd20756dSMatthias Springer             return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1064bd20756dSMatthias Springer           },
1065bd20756dSMatthias Springer           /*outOfBoundsCase=*/
1066bd20756dSMatthias Springer           [&](OpBuilder &b, Location loc) {
1067bd20756dSMatthias Springer             return isTensorOp(xferOp) ? source : Value();
10680f241638SMatthias Springer           });
1069bd20756dSMatthias Springer 
1070bd20756dSMatthias Springer       if (isTensorOp(xferOp))
1071bd20756dSMatthias Springer         source = updatedSource;
10720f241638SMatthias Springer     }
10730f241638SMatthias Springer 
1074bd20756dSMatthias Springer     if (isTensorOp(xferOp))
1075bd20756dSMatthias Springer       rewriter.replaceOp(xferOp, source);
1076bd20756dSMatthias Springer     else
10770f241638SMatthias Springer       rewriter.eraseOp(xferOp);
1078bd20756dSMatthias Springer 
10790f241638SMatthias Springer     return success();
10800f241638SMatthias Springer   }
10810f241638SMatthias Springer };
10820f241638SMatthias Springer 
1083a088bed4SMatthias Springer } // namespace lowering_n_d_unrolled
1084a088bed4SMatthias Springer 
1085a088bed4SMatthias Springer namespace lowering_1_d {
1086a088bed4SMatthias Springer 
10870f241638SMatthias Springer /// Compute the indices into the memref for the LoadOp/StoreOp generated as
10880f241638SMatthias Springer /// part of TransferOp1dConversion. Return the memref dimension on which
10890f241638SMatthias Springer /// the transfer is operating. A return value of None indicates a broadcast.
10900f241638SMatthias Springer template <typename OpTy>
10910f241638SMatthias Springer static Optional<int64_t>
get1dMemrefIndices(OpBuilder & b,OpTy xferOp,Value iv,SmallVector<Value,8> & memrefIndices)10926825bfe2SNicolas Vasilache get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
10930f241638SMatthias Springer                    SmallVector<Value, 8> &memrefIndices) {
10947c38fd60SJacques Pienaar   auto indices = xferOp.getIndices();
10957c38fd60SJacques Pienaar   auto map = xferOp.getPermutationMap();
1096c537a943SNicolas Vasilache   assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
10970f241638SMatthias Springer 
10980f241638SMatthias Springer   memrefIndices.append(indices.begin(), indices.end());
10990f241638SMatthias Springer   assert(map.getNumResults() == 1 &&
11000f241638SMatthias Springer          "Expected 1 permutation map result for 1D transfer");
11010f241638SMatthias Springer   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
11026825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
11030f241638SMatthias Springer     auto dim = expr.getPosition();
11046825bfe2SNicolas Vasilache     AffineExpr d0, d1;
11056825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
11066825bfe2SNicolas Vasilache     Value offset = memrefIndices[dim];
11076825bfe2SNicolas Vasilache     memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
11080f241638SMatthias Springer     return dim;
11090f241638SMatthias Springer   }
11100f241638SMatthias Springer 
11110f241638SMatthias Springer   assert(xferOp.isBroadcastDim(0) &&
11120f241638SMatthias Springer          "Expected AffineDimExpr or AffineConstantExpr");
11130f241638SMatthias Springer   return None;
11140f241638SMatthias Springer }
11150f241638SMatthias Springer 
11160f241638SMatthias Springer /// Codegen strategy for TransferOp1dConversion, depending on the
11170f241638SMatthias Springer /// operation.
11180f241638SMatthias Springer template <typename OpTy>
11190f241638SMatthias Springer struct Strategy1d;
11200f241638SMatthias Springer 
11210f241638SMatthias Springer /// Codegen strategy for TransferReadOp.
11220f241638SMatthias Springer template <>
11230f241638SMatthias Springer struct Strategy1d<TransferReadOp> {
generateForLoopBody__anon4d9edda10111::lowering_1_d::Strategy1d11246825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
11250f241638SMatthias Springer                                   TransferReadOp xferOp, Value iv,
11260f241638SMatthias Springer                                   ValueRange loopState) {
11270f241638SMatthias Springer     SmallVector<Value, 8> indices;
11286825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
11290f241638SMatthias Springer     auto vec = loopState[0];
11300f241638SMatthias Springer 
11310f241638SMatthias Springer     // In case of out-of-bounds access, leave `vec` as is (was initialized with
11320f241638SMatthias Springer     // padding value).
11330f241638SMatthias Springer     auto nextVec = generateInBoundsCheck(
11346825bfe2SNicolas Vasilache         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
11350f241638SMatthias Springer         /*inBoundsCase=*/
11366825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
11377c38fd60SJacques Pienaar           Value val =
11387c38fd60SJacques Pienaar               b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
11397c5ecc8bSMogball           return b.create<vector::InsertElementOp>(loc, val, vec, iv);
11400f241638SMatthias Springer         },
11410f241638SMatthias Springer         /*outOfBoundsCase=*/
11420f241638SMatthias Springer         [&](OpBuilder & /*b*/, Location loc) { return vec; });
11436825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc, nextVec);
11440f241638SMatthias Springer   }
11450f241638SMatthias Springer 
initialLoopState__anon4d9edda10111::lowering_1_d::Strategy1d11466825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
11470f241638SMatthias Springer     // Inititalize vector with padding value.
11486825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
11496a8ba318SRiver Riddle     return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
11507c38fd60SJacques Pienaar                                      xferOp.getPadding());
11510f241638SMatthias Springer   }
11520f241638SMatthias Springer };
11530f241638SMatthias Springer 
11540f241638SMatthias Springer /// Codegen strategy for TransferWriteOp.
11550f241638SMatthias Springer template <>
11560f241638SMatthias Springer struct Strategy1d<TransferWriteOp> {
generateForLoopBody__anon4d9edda10111::lowering_1_d::Strategy1d11576825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
11580f241638SMatthias Springer                                   TransferWriteOp xferOp, Value iv,
11590f241638SMatthias Springer                                   ValueRange /*loopState*/) {
11600f241638SMatthias Springer     SmallVector<Value, 8> indices;
11616825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
11620f241638SMatthias Springer 
11630f241638SMatthias Springer     // Nothing to do in case of out-of-bounds access.
11640f241638SMatthias Springer     generateInBoundsCheck(
11656825bfe2SNicolas Vasilache         b, xferOp, iv, dim,
11666825bfe2SNicolas Vasilache         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
11676825bfe2SNicolas Vasilache           auto val =
11687c38fd60SJacques Pienaar               b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
11697c38fd60SJacques Pienaar           b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
11700f241638SMatthias Springer         });
11716825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc);
11720f241638SMatthias Springer   }
11730f241638SMatthias Springer 
initialLoopState__anon4d9edda10111::lowering_1_d::Strategy1d11746825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
11756825bfe2SNicolas Vasilache     return Value();
11766825bfe2SNicolas Vasilache   }
11770f241638SMatthias Springer };
11780f241638SMatthias Springer 
11790f241638SMatthias Springer /// Return true if the last dimension of the MemRefType has unit stride.
isLastMemrefDimUnitStride(MemRefType type)11800f241638SMatthias Springer static bool isLastMemrefDimUnitStride(MemRefType type) {
11810f241638SMatthias Springer   int64_t offset;
11820f241638SMatthias Springer   SmallVector<int64_t, 4> strides;
11830f241638SMatthias Springer   auto successStrides = getStridesAndOffset(type, strides, offset);
11845017b0f8SMatthias Springer   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
11850f241638SMatthias Springer }
11860f241638SMatthias Springer 
11870f241638SMatthias Springer /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
11880f241638SMatthias Springer /// necessary in cases where a 1D vector transfer op cannot be lowered into
11890f241638SMatthias Springer /// vector load/stores due to non-unit strides or broadcasts:
11900f241638SMatthias Springer ///
11910f241638SMatthias Springer /// * Transfer dimension is not the last memref dimension
11920f241638SMatthias Springer /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
11930f241638SMatthias Springer /// * Memref has a layout map with non-unit stride on the last dimension
11940f241638SMatthias Springer ///
11950f241638SMatthias Springer /// This pattern generates IR as follows:
11960f241638SMatthias Springer ///
11970f241638SMatthias Springer /// 1. Generate a for loop iterating over each vector element.
11980f241638SMatthias Springer /// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
11990f241638SMatthias Springer ///    depending on OpTy.
12000f241638SMatthias Springer ///
12010f241638SMatthias Springer /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
12020f241638SMatthias Springer ///       can be generated instead of TransferOp1dConversion. Add such a pattern
12030f241638SMatthias Springer ///       to ConvertVectorToLLVM.
12040f241638SMatthias Springer ///
12050f241638SMatthias Springer /// E.g.:
12060f241638SMatthias Springer /// ```
12070f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b]
12080f241638SMatthias Springer ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
12090f241638SMatthias Springer ///    : vector<9xf32>, memref<?x?xf32>
12100f241638SMatthias Springer /// ```
12110f241638SMatthias Springer /// Is rewritten to approximately the following pseudo-IR:
12120f241638SMatthias Springer /// ```
12130f241638SMatthias Springer /// for i = 0 to 9 {
12140f241638SMatthias Springer ///   %t = vector.extractelement %vec[i] : vector<9xf32>
12150f241638SMatthias Springer ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
12160f241638SMatthias Springer /// }
12170f241638SMatthias Springer /// ```
12180f241638SMatthias Springer template <typename OpTy>
12192ca887deSMatthias Springer struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
12202ca887deSMatthias Springer   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
12210f241638SMatthias Springer 
matchAndRewrite__anon4d9edda10111::lowering_1_d::TransferOp1dConversion12220f241638SMatthias Springer   LogicalResult matchAndRewrite(OpTy xferOp,
12230f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
1224c537a943SNicolas Vasilache     // TODO: support 0-d corner case.
1225c537a943SNicolas Vasilache     if (xferOp.getTransferRank() == 0)
1226c537a943SNicolas Vasilache       return failure();
12277c38fd60SJacques Pienaar     auto map = xferOp.getPermutationMap();
12280f241638SMatthias Springer     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
12290f241638SMatthias Springer 
12300f241638SMatthias Springer     if (!memRefType)
12310f241638SMatthias Springer       return failure();
12320f241638SMatthias Springer     if (xferOp.getVectorType().getRank() != 1)
12330f241638SMatthias Springer       return failure();
12340f241638SMatthias Springer     if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
12350f241638SMatthias Springer       return failure(); // Handled by ConvertVectorToLLVM
12360f241638SMatthias Springer 
12370f241638SMatthias Springer     // Loop bounds, step, state...
12386825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
12390f241638SMatthias Springer     auto vecType = xferOp.getVectorType();
1240a54f4eaeSMogball     auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1241a54f4eaeSMogball     auto ub =
1242a54f4eaeSMogball         rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
1243a54f4eaeSMogball     auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
12446825bfe2SNicolas Vasilache     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
12450f241638SMatthias Springer 
12460f241638SMatthias Springer     // Generate for loop.
12470f241638SMatthias Springer     rewriter.replaceOpWithNewOp<scf::ForOp>(
12480f241638SMatthias Springer         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
12496825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
12506825bfe2SNicolas Vasilache           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
12510f241638SMatthias Springer         });
12520f241638SMatthias Springer 
12530f241638SMatthias Springer     return success();
12540f241638SMatthias Springer   }
12550f241638SMatthias Springer };
12564ead2cf7SAlex Zinenko 
1257a088bed4SMatthias Springer } // namespace lowering_1_d
1258df63eedeSBenjamin Kramer } // namespace
1259df63eedeSBenjamin Kramer 
populateVectorToSCFConversionPatterns(RewritePatternSet & patterns,const VectorTransferToSCFOptions & options)126047f175b0SRiver Riddle void mlir::populateVectorToSCFConversionPatterns(
1261dc4e913bSChris Lattner     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
12620f241638SMatthias Springer   if (options.unroll) {
1263a088bed4SMatthias Springer     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1264a088bed4SMatthias Springer                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
12652ca887deSMatthias Springer         patterns.getContext(), options);
12660f241638SMatthias Springer   } else {
1267a088bed4SMatthias Springer     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1268a088bed4SMatthias Springer                  lowering_n_d::PrepareTransferWriteConversion,
1269a088bed4SMatthias Springer                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1270a088bed4SMatthias Springer                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1271a088bed4SMatthias Springer         patterns.getContext(), options);
12720f241638SMatthias Springer   }
12730f241638SMatthias Springer 
12742ca887deSMatthias Springer   if (options.targetRank == 1) {
1275a088bed4SMatthias Springer     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1276a088bed4SMatthias Springer                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1277a088bed4SMatthias Springer         patterns.getContext(), options);
12780f241638SMatthias Springer   }
12794ead2cf7SAlex Zinenko }
12803393cc4cSNicolas Vasilache 
12815f9e0466SNicolas Vasilache namespace {
12825f9e0466SNicolas Vasilache 
12835f9e0466SNicolas Vasilache struct ConvertVectorToSCFPass
12845f9e0466SNicolas Vasilache     : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
12855f9e0466SNicolas Vasilache   ConvertVectorToSCFPass() = default;
ConvertVectorToSCFPass__anon4d9edda11411::ConvertVectorToSCFPass12865f9e0466SNicolas Vasilache   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
12875f9e0466SNicolas Vasilache     this->fullUnroll = options.unroll;
12882ca887deSMatthias Springer     this->targetRank = options.targetRank;
1289fb7ec1f1SMatthias Springer     this->lowerPermutationMaps = options.lowerPermutationMaps;
1290558e7401SMatthias Springer     this->lowerTensors = options.lowerTensors;
12915f9e0466SNicolas Vasilache   }
12925f9e0466SNicolas Vasilache 
runOnOperation__anon4d9edda11411::ConvertVectorToSCFPass129341574554SRiver Riddle   void runOnOperation() override {
12942ca887deSMatthias Springer     VectorTransferToSCFOptions options;
1295fb7ec1f1SMatthias Springer     options.unroll = fullUnroll;
1296fb7ec1f1SMatthias Springer     options.targetRank = targetRank;
1297fb7ec1f1SMatthias Springer     options.lowerPermutationMaps = lowerPermutationMaps;
1298558e7401SMatthias Springer     options.lowerTensors = lowerTensors;
1299fb7ec1f1SMatthias Springer 
1300fb7ec1f1SMatthias Springer     // Lower permutation maps first.
1301fb7ec1f1SMatthias Springer     if (lowerPermutationMaps) {
130247f175b0SRiver Riddle       RewritePatternSet lowerTransferPatterns(&getContext());
1303fb7ec1f1SMatthias Springer       mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1304fb7ec1f1SMatthias Springer           lowerTransferPatterns);
130541574554SRiver Riddle       (void)applyPatternsAndFoldGreedily(getOperation(),
1306fb7ec1f1SMatthias Springer                                          std::move(lowerTransferPatterns));
1307fb7ec1f1SMatthias Springer     }
13082ca887deSMatthias Springer 
130947f175b0SRiver Riddle     RewritePatternSet patterns(&getContext());
13102ca887deSMatthias Springer     populateVectorToSCFConversionPatterns(patterns, options);
131141574554SRiver Riddle     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
13125f9e0466SNicolas Vasilache   }
13135f9e0466SNicolas Vasilache };
13145f9e0466SNicolas Vasilache 
13155f9e0466SNicolas Vasilache } // namespace
13165f9e0466SNicolas Vasilache 
13175f9e0466SNicolas Vasilache std::unique_ptr<Pass>
createConvertVectorToSCFPass(const VectorTransferToSCFOptions & options)13185f9e0466SNicolas Vasilache mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
13195f9e0466SNicolas Vasilache   return std::make_unique<ConvertVectorToSCFPass>(options);
13205f9e0466SNicolas Vasilache }
1321