//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector transfer operations to SCF.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/Dialect/Vector/VectorUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using vector::TransferReadOp;
using vector::TransferWriteOp;

namespace {

/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  VectorTransferToSCFOptions options;
};

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
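/// For example (illustrative only): with a permutation map
/// `(d0, d1, d2) -> (d2, d1)`, the first result is `d2`, so dimension 2 of the
/// source is unpacked next; with `(d0, d1) -> (0, d1)`, the first result is
/// the constant 0 (a broadcast), so None is returned.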
template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
  auto map = xferOp.permutation_map();
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
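/// For example (illustrative only): `(d0, d1, d2) -> (d2, d1)` becomes
/// `(d0, d1, d2) -> (d1)`; the number of dims is unchanged, only the leading
/// result is dropped.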
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  auto map = xferOp.permutation_map();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}

/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///                                 ^^^^^^
///              `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.indices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.hasValue();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.indices()[dim.getValue()];
    indices[dim.getValue()] =
        makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
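/// A sketch of the generated check (simplified, illustrative only; assumes a
/// 1D mask `%mask` and loop variable `%iv`):
/// ```
/// %pos = index_cast %iv : index to i32
/// %in = vector.extractelement %mask[%pos : i32] : vector<5xi1>
/// ```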
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.mask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  Value ivI32 =
      b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
  return b.create<vector::ExtractElementOp>(loc, xferOp.mask(), ivI32);
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b], %cst
///     : memref<?x?xf32>, vector<5x4xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim = vector::createOrFoldDimOp(b, loc, xferOp.source(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.indices()[dim.getValue()];
    Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<CmpIOp>(CmpIPredicate::sgt, memrefDim, memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = lb.create<AndOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}

/// Given an ArrayAttr, return a copy where the first element is dropped.
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  if (!attr)
    return attr;
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}

/// Add the pass label to a vector transfer op if its rank is not the target
/// rank.
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

/// Return true if this transfer op operates on a source tensor.
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
  if (xferOp.getShapedType().template isa<RankedTensorType>()) {
    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}

namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

/// Allocate temporary buffers for data (vector) and mask (if present).
/// TODO: Parallelism and threadlocal considerations.
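/// A sketch of the buffers generated for a masked transfer op (simplified,
/// illustrative types only):
/// ```
/// %data = memref.alloca() : memref<vector<5x4xf32>>
/// %maskBuf = memref.alloca() : memref<vector<5xi1>>
/// memref.store %mask, %maskBuf[] : memref<vector<5xi1>>
/// %maskVal = memref.load %maskBuf[] : memref<vector<5xi1>>
/// ```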
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope =
      xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.mask()) {
    auto maskType = MemRefType::get({}, xferOp.mask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.mask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
  }

  return result;
}

/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static MemRefType unpackOneDim(MemRefType type) {
  auto vectorType = type.getElementType().dyn_cast<VectorType>();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::get(vectorType.getShape().drop_front(),
                                         vectorType.getElementType()));
}

/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.mask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the operation.
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.source(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), xferOp.padding(),
        Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.vector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
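  /// A sketch of the IR generated per out-of-bounds iteration (simplified,
  /// illustrative types only):
  /// ```
  /// %pad = splat %padding : vector<3xf32>
  /// memref.store %pad, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// ```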
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto vec = b.create<SplatOp>(loc, vecType, xferOp.padding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};

/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
    auto source = loopState.empty() ? xferOp.source() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.source() : Value();
  }
};

template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return failure();
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();
  // Transfer ops that modify the element type are not supported atm.
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : memref<?x?x?xf32>, vector<5x4xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.mask()) {
      dyn_cast<TransferReadOp>(newXfer).maskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};

/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.vector(), buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.vectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.mask()) {
      rewriter.updateRootInPlace(
          xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); });
    }

    return success();
  }
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the generated
/// scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.mask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
    auto lb = locB.create<ConstantIndexOp>(0);
    auto ub = locB.create<ConstantIndexOp>(
        castedDataType.getDimSize(castedDataType.getRank() - 1));
    auto step = locB.create<ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
                if (xferOp.mask() && (xferOp.isBroadcastDim(0) ||
                                      xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from memref
                  // as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.updateRootInPlace(
                      newXfer, [&]() { newXfer.maskMutable().assign(mask); });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {

/// If the original transfer op has a mask, compute the mask of the new transfer
/// op (for the current iteration `i`) and assign it.
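/// For example (illustrative only), when unpacking a non-broadcast dimension
/// of a 2D mask at iteration i = 2, a 1D subvector of the mask is extracted
/// and used as the new mask:
/// ```
/// %newMask = vector.extract %mask[2] : vector<5x4xi1>
/// ```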
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.mask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.maskMutable().assign(xferOp.mask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.mask(), indices);
    newXferOp.maskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
8270f241638SMatthias Springer /// ```
8280f241638SMatthias Springer ///
8290f241638SMatthias Springer /// Note: As an optimization, if the result of the original TransferReadOp
8300f241638SMatthias Springer /// was directly inserted into another vector, no new %v_init vector is created.
8310f241638SMatthias Springer /// Instead, the new TransferReadOp results are inserted into that vector.
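///
/// A hypothetical sketch of that optimization (all names are illustrative,
/// not taken from actual IR): if the read's single user is
/// ```
/// %w = vector.insert %vec, %acc[2] : vector<5x4xf32> into vector<7x5x4xf32>
/// ```
/// then the unrolled reads are inserted directly into %acc at positions
/// [2, 0], [2, 1], ..., and no splat of %padding is created.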
8322ca887deSMatthias Springer struct UnrollTransferReadConversion
8332ca887deSMatthias Springer     : public VectorToSCFPattern<TransferReadOp> {
8342ca887deSMatthias Springer   using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
8350f241638SMatthias Springer 
836700b64dcSMatthias Springer   void initialize() {
837700b64dcSMatthias Springer     // This pattern recursively unpacks one dimension at a time. The recursion
838700b64dcSMatthias Springer     // is bounded as the rank is strictly decreasing.
839700b64dcSMatthias Springer     setHasBoundedRewriteRecursion();
840700b64dcSMatthias Springer   }
841700b64dcSMatthias Springer 
8420f241638SMatthias Springer   /// Return the vector into which the newly created TransferReadOp results
8430f241638SMatthias Springer   /// are inserted.
8440f241638SMatthias Springer   Value getResultVector(TransferReadOp xferOp,
8450f241638SMatthias Springer                         PatternRewriter &rewriter) const {
8460f241638SMatthias Springer     if (auto insertOp = getInsertOp(xferOp))
8470f241638SMatthias Springer       return insertOp.dest();
8486825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
8496825bfe2SNicolas Vasilache     return rewriter.create<SplatOp>(loc, xferOp.getVectorType(),
8506825bfe2SNicolas Vasilache                                     xferOp.padding());
8510f241638SMatthias Springer   }
8520f241638SMatthias Springer 
8530f241638SMatthias Springer   /// If the result of the TransferReadOp has exactly one user, which is a
8540f241638SMatthias Springer   /// vector::InsertOp, return that operation.
8550f241638SMatthias Springer   vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
8560f241638SMatthias Springer     if (xferOp->hasOneUse()) {
8570f241638SMatthias Springer       Operation *xferOpUser = *xferOp->getUsers().begin();
8580f241638SMatthias Springer       if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
8590f241638SMatthias Springer         return insertOp;
8600f241638SMatthias Springer     }
8610f241638SMatthias Springer 
8620f241638SMatthias Springer     return vector::InsertOp();
8630f241638SMatthias Springer   }
8640f241638SMatthias Springer 
8650f241638SMatthias Springer   /// If the result of the TransferReadOp has exactly one user, which is a
8660f241638SMatthias Springer   /// vector::InsertOp, return that operation's indices.
8670f241638SMatthias Springer   void getInsertionIndices(TransferReadOp xferOp,
8680f241638SMatthias Springer                            SmallVector<int64_t, 8> &indices) const {
8690f241638SMatthias Springer     if (auto insertOp = getInsertOp(xferOp)) {
8700f241638SMatthias Springer       llvm::for_each(insertOp.position(), [&](Attribute attr) {
8710f241638SMatthias Springer         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
8720f241638SMatthias Springer       });
8730f241638SMatthias Springer     }
8740f241638SMatthias Springer   }
8750f241638SMatthias Springer 
8760f241638SMatthias Springer   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
8770f241638SMatthias Springer   /// accesses, and broadcasts and transposes in permutation maps.
8780f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferReadOp xferOp,
8790f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
8802ca887deSMatthias Springer     if (xferOp.getVectorType().getRank() <= options.targetRank)
8810f241638SMatthias Springer       return failure();
882bd20756dSMatthias Springer     if (isTensorOp(xferOp) && !options.lowerTensors)
8838fb48979SMatthias Springer       return failure();
884f718a53dSMatthias Springer     // Transfer ops that modify the element type are currently not supported.
885f718a53dSMatthias Springer     if (xferOp.getVectorType().getElementType() !=
886f718a53dSMatthias Springer         xferOp.getShapedType().getElementType())
887f718a53dSMatthias Springer       return failure();
8880f241638SMatthias Springer 
8890f241638SMatthias Springer     auto insertOp = getInsertOp(xferOp);
8900f241638SMatthias Springer     auto vec = getResultVector(xferOp, rewriter);
8910f241638SMatthias Springer     auto vecType = vec.getType().dyn_cast<VectorType>();
8920f241638SMatthias Springer     auto xferVecType = xferOp.getVectorType();
8930f241638SMatthias Springer     auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
8940f241638SMatthias Springer                                           xferVecType.getElementType());
8950f241638SMatthias Springer     int64_t dimSize = xferVecType.getShape()[0];
8960f241638SMatthias Springer 
8970f241638SMatthias Springer     // Generate fully unrolled loop of transfer ops.
8986825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
8990f241638SMatthias Springer     for (int64_t i = 0; i < dimSize; ++i) {
9006825bfe2SNicolas Vasilache       Value iv = rewriter.create<ConstantIndexOp>(loc, i);
9010f241638SMatthias Springer 
9020f241638SMatthias Springer       vec = generateInBoundsCheck(
9036825bfe2SNicolas Vasilache           rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
9040f241638SMatthias Springer           /*inBoundsCase=*/
9050f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
9060f241638SMatthias Springer             // Indices for the new transfer op.
9070f241638SMatthias Springer             SmallVector<Value, 8> xferIndices;
9086825bfe2SNicolas Vasilache             getXferIndices(b, xferOp, iv, xferIndices);
9090f241638SMatthias Springer 
9100f241638SMatthias Springer             // Indices for the new vector.insert op.
9110f241638SMatthias Springer             SmallVector<int64_t, 8> insertionIndices;
9120f241638SMatthias Springer             getInsertionIndices(xferOp, insertionIndices);
9130f241638SMatthias Springer             insertionIndices.push_back(i);
9140f241638SMatthias Springer 
9150f241638SMatthias Springer             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
9166825bfe2SNicolas Vasilache             auto newXferOp = b.create<vector::TransferReadOp>(
9176825bfe2SNicolas Vasilache                 loc, newXferVecType, xferOp.source(), xferIndices,
9186825bfe2SNicolas Vasilache                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
9196825bfe2SNicolas Vasilache                 xferOp.padding(), Value(), inBoundsAttr);
9200f241638SMatthias Springer             maybeAssignMask(b, xferOp, newXferOp, i);
9216825bfe2SNicolas Vasilache             return b.create<vector::InsertOp>(loc, newXferOp, vec,
9226825bfe2SNicolas Vasilache                                               insertionIndices);
9230f241638SMatthias Springer           },
9240f241638SMatthias Springer           /*outOfBoundsCase=*/
9250f241638SMatthias Springer           [&](OpBuilder &b, Location loc) {
9260f241638SMatthias Springer             // Loop through original (unmodified) vector.
9270f241638SMatthias Springer             return vec;
9280f241638SMatthias Springer           });
9290f241638SMatthias Springer     }
9300f241638SMatthias Springer 
9310f241638SMatthias Springer     if (insertOp) {
9320f241638SMatthias Springer       // Rewrite single user of the old TransferReadOp, which was an InsertOp.
9330f241638SMatthias Springer       rewriter.replaceOp(insertOp, vec);
9340f241638SMatthias Springer       rewriter.eraseOp(xferOp);
9350f241638SMatthias Springer     } else {
9360f241638SMatthias Springer       rewriter.replaceOp(xferOp, vec);
9370f241638SMatthias Springer     }
9380f241638SMatthias Springer 
9390f241638SMatthias Springer     return success();
9400f241638SMatthias Springer   }
9410f241638SMatthias Springer };
9420f241638SMatthias Springer 
9430f241638SMatthias Springer /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
9440f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
9450f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
9460f241638SMatthias Springer ///
9480f241638SMatthias Springer /// E.g.:
9490f241638SMatthias Springer /// ```
9500f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c]
9510f241638SMatthias Springer ///     : vector<5x4xf32>, memref<?x?x?xf32>
9520f241638SMatthias Springer /// ```
9530f241638SMatthias Springer /// is rewritten to IR such as (simplified):
9540f241638SMatthias Springer /// ```
9550f241638SMatthias Springer /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
9560f241638SMatthias Springer /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
9570f241638SMatthias Springer /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
9580f241638SMatthias Springer /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
9590f241638SMatthias Springer /// ...
9600f241638SMatthias Springer /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
9610f241638SMatthias Springer /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
9620f241638SMatthias Springer /// ```
9630f241638SMatthias Springer ///
9640f241638SMatthias Springer /// Note: As an optimization, if the vector of the original TransferWriteOp
9650f241638SMatthias Springer /// was directly extracted from another vector via an ExtractOp `a`, extract
9660f241638SMatthias Springer /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
9670f241638SMatthias Springer /// doing so, `a` may become dead, and the number of ExtractOps generated during
9680f241638SMatthias Springer /// recursive application of this pattern will be minimal.
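///
/// A hypothetical sketch of that optimization (all names are illustrative,
/// not taken from actual IR): if the written vector is produced by
/// ```
/// %vec = vector.extract %src[2] : vector<7x5x4xf32>
/// ```
/// then the unrolled writes extract their slices directly from %src at
/// positions [2, 0], [2, 1], ..., so the original ExtractOp may become dead.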
9690f241638SMatthias Springer struct UnrollTransferWriteConversion
9702ca887deSMatthias Springer     : public VectorToSCFPattern<TransferWriteOp> {
9712ca887deSMatthias Springer   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
9720f241638SMatthias Springer 
973700b64dcSMatthias Springer   void initialize() {
974700b64dcSMatthias Springer     // This pattern recursively unpacks one dimension at a time. The recursion
975700b64dcSMatthias Springer     // is bounded as the rank is strictly decreasing.
976700b64dcSMatthias Springer     setHasBoundedRewriteRecursion();
977700b64dcSMatthias Springer   }
978700b64dcSMatthias Springer 
9790f241638SMatthias Springer   /// Return the vector from which newly generated ExtractOps will extract.
9800f241638SMatthias Springer   Value getDataVector(TransferWriteOp xferOp) const {
9810f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp))
9820f241638SMatthias Springer       return extractOp.vector();
9830f241638SMatthias Springer     return xferOp.vector();
9840f241638SMatthias Springer   }
9850f241638SMatthias Springer 
9860f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
9870f241638SMatthias Springer   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
9880f241638SMatthias Springer     if (auto *op = xferOp.vector().getDefiningOp())
9890f241638SMatthias Springer       return dyn_cast<vector::ExtractOp>(op);
9900f241638SMatthias Springer     return vector::ExtractOp();
9910f241638SMatthias Springer   }
9920f241638SMatthias Springer 
9930f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return its
9940f241638SMatthias Springer   /// indices.
9950f241638SMatthias Springer   void getExtractionIndices(TransferWriteOp xferOp,
9960f241638SMatthias Springer                             SmallVector<int64_t, 8> &indices) const {
9970f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp)) {
9980f241638SMatthias Springer       llvm::for_each(extractOp.position(), [&](Attribute attr) {
9990f241638SMatthias Springer         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
10000f241638SMatthias Springer       });
10010f241638SMatthias Springer     }
10020f241638SMatthias Springer   }
10030f241638SMatthias Springer 
10040f241638SMatthias Springer   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
10050f241638SMatthias Springer   /// accesses, and broadcasts and transposes in permutation maps.
10060f241638SMatthias Springer   LogicalResult matchAndRewrite(TransferWriteOp xferOp,
10070f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
10082ca887deSMatthias Springer     if (xferOp.getVectorType().getRank() <= options.targetRank)
10090f241638SMatthias Springer       return failure();
1010bd20756dSMatthias Springer     if (isTensorOp(xferOp) && !options.lowerTensors)
10118fb48979SMatthias Springer       return failure();
1012f718a53dSMatthias Springer     // Transfer ops that modify the element type are currently not supported.
1013f718a53dSMatthias Springer     if (xferOp.getVectorType().getElementType() !=
1014f718a53dSMatthias Springer         xferOp.getShapedType().getElementType())
1015f718a53dSMatthias Springer       return failure();
10160f241638SMatthias Springer 
10170f241638SMatthias Springer     auto vec = getDataVector(xferOp);
10180f241638SMatthias Springer     auto xferVecType = xferOp.getVectorType();
10190f241638SMatthias Springer     int64_t dimSize = xferVecType.getShape()[0];
1020bd20756dSMatthias Springer     auto source = xferOp.source(); // memref or tensor to be written to.
1021bd20756dSMatthias Springer     auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
10220f241638SMatthias Springer 
10230f241638SMatthias Springer     // Generate fully unrolled loop of transfer ops.
10246825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
10250f241638SMatthias Springer     for (int64_t i = 0; i < dimSize; ++i) {
10266825bfe2SNicolas Vasilache       Value iv = rewriter.create<ConstantIndexOp>(loc, i);
10270f241638SMatthias Springer 
1028bd20756dSMatthias Springer       auto updatedSource = generateInBoundsCheck(
10296825bfe2SNicolas Vasilache           rewriter, xferOp, iv, unpackedDim(xferOp),
1030bd20756dSMatthias Springer           isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1031bd20756dSMatthias Springer           /*inBoundsCase=*/
1032bd20756dSMatthias Springer           [&](OpBuilder &b, Location loc) {
10330f241638SMatthias Springer             // Indices for the new transfer op.
10340f241638SMatthias Springer             SmallVector<Value, 8> xferIndices;
10356825bfe2SNicolas Vasilache             getXferIndices(b, xferOp, iv, xferIndices);
10360f241638SMatthias Springer 
10370f241638SMatthias Springer             // Indices for the new vector.extract op.
10380f241638SMatthias Springer             SmallVector<int64_t, 8> extractionIndices;
10390f241638SMatthias Springer             getExtractionIndices(xferOp, extractionIndices);
10400f241638SMatthias Springer             extractionIndices.push_back(i);
10410f241638SMatthias Springer 
10426825bfe2SNicolas Vasilache             auto extracted =
10436825bfe2SNicolas Vasilache                 b.create<vector::ExtractOp>(loc, vec, extractionIndices);
10440f241638SMatthias Springer             auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr());
10456825bfe2SNicolas Vasilache             auto newXferOp = b.create<vector::TransferWriteOp>(
1046bd20756dSMatthias Springer                 loc, sourceType, extracted, source, xferIndices,
10476825bfe2SNicolas Vasilache                 AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
10486825bfe2SNicolas Vasilache                 inBoundsAttr);
10490f241638SMatthias Springer 
10500f241638SMatthias Springer             maybeAssignMask(b, xferOp, newXferOp, i);
1051bd20756dSMatthias Springer 
1052bd20756dSMatthias Springer             return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1053bd20756dSMatthias Springer           },
1054bd20756dSMatthias Springer           /*outOfBoundsCase=*/
1055bd20756dSMatthias Springer           [&](OpBuilder &b, Location loc) {
1056bd20756dSMatthias Springer             return isTensorOp(xferOp) ? source : Value();
10570f241638SMatthias Springer           });
1058bd20756dSMatthias Springer 
1059bd20756dSMatthias Springer       if (isTensorOp(xferOp))
1060bd20756dSMatthias Springer         source = updatedSource;
10610f241638SMatthias Springer     }
10620f241638SMatthias Springer 
1063bd20756dSMatthias Springer     if (isTensorOp(xferOp))
1064bd20756dSMatthias Springer       rewriter.replaceOp(xferOp, source);
1065bd20756dSMatthias Springer     else
10660f241638SMatthias Springer       rewriter.eraseOp(xferOp);
1067bd20756dSMatthias Springer 
10680f241638SMatthias Springer     return success();
10690f241638SMatthias Springer   }
10700f241638SMatthias Springer };
10710f241638SMatthias Springer 
1072a088bed4SMatthias Springer } // namespace lowering_n_d_unrolled
1073a088bed4SMatthias Springer 
1074a088bed4SMatthias Springer namespace lowering_1_d {
1075a088bed4SMatthias Springer 
10760f241638SMatthias Springer /// Compute the indices into the memref for the LoadOp/StoreOp generated as
10770f241638SMatthias Springer /// part of TransferOp1dConversion. Return the memref dimension on which
10780f241638SMatthias Springer /// the transfer is operating. A return value of None indicates a broadcast.
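///
/// A hypothetical example (values are illustrative): for a transfer with
/// indices `(%a, %b)`, permutation map `(d0, d1) -> (d0)` and induction
/// variable `%iv`, the generated memref indices are:
/// ```
/// %i0 = affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%a, %iv)
/// // memrefIndices = (%i0, %b); the returned dimension is 0.
/// ```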
10790f241638SMatthias Springer template <typename OpTy>
10800f241638SMatthias Springer static Optional<int64_t>
10816825bfe2SNicolas Vasilache get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
10820f241638SMatthias Springer                    SmallVector<Value, 8> &memrefIndices) {
10830f241638SMatthias Springer   auto indices = xferOp.indices();
10840f241638SMatthias Springer   auto map = xferOp.permutation_map();
10850f241638SMatthias Springer 
10860f241638SMatthias Springer   memrefIndices.append(indices.begin(), indices.end());
10870f241638SMatthias Springer   assert(map.getNumResults() == 1 &&
10880f241638SMatthias Springer          "Expected 1 permutation map result for 1D transfer");
10890f241638SMatthias Springer   if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
10906825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
10910f241638SMatthias Springer     auto dim = expr.getPosition();
10926825bfe2SNicolas Vasilache     AffineExpr d0, d1;
10936825bfe2SNicolas Vasilache     bindDims(xferOp.getContext(), d0, d1);
10946825bfe2SNicolas Vasilache     Value offset = memrefIndices[dim];
10956825bfe2SNicolas Vasilache     memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
10960f241638SMatthias Springer     return dim;
10970f241638SMatthias Springer   }
10980f241638SMatthias Springer 
10990f241638SMatthias Springer   assert(xferOp.isBroadcastDim(0) &&
11000f241638SMatthias Springer          "Expected AffineDimExpr or AffineConstantExpr");
11010f241638SMatthias Springer   return None;
11020f241638SMatthias Springer }
11030f241638SMatthias Springer 
11040f241638SMatthias Springer /// Codegen strategy for TransferOp1dConversion, depending on the
11050f241638SMatthias Springer /// operation.
11060f241638SMatthias Springer template <typename OpTy>
11070f241638SMatthias Springer struct Strategy1d;
11080f241638SMatthias Springer 
11090f241638SMatthias Springer /// Codegen strategy for TransferReadOp.
11100f241638SMatthias Springer template <>
11110f241638SMatthias Springer struct Strategy1d<TransferReadOp> {
11126825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
11130f241638SMatthias Springer                                   TransferReadOp xferOp, Value iv,
11140f241638SMatthias Springer                                   ValueRange loopState) {
11150f241638SMatthias Springer     SmallVector<Value, 8> indices;
11166825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
11176825bfe2SNicolas Vasilache     Value ivI32 =
11186825bfe2SNicolas Vasilache         b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
11190f241638SMatthias Springer     auto vec = loopState[0];
11200f241638SMatthias Springer 
11210f241638SMatthias Springer     // In case of out-of-bounds access, leave `vec` as is (it was initialized
11220f241638SMatthias Springer     // with the padding value).
11230f241638SMatthias Springer     auto nextVec = generateInBoundsCheck(
11246825bfe2SNicolas Vasilache         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
11250f241638SMatthias Springer         /*inBoundsCase=*/
11266825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
11276825bfe2SNicolas Vasilache           Value val = b.create<memref::LoadOp>(loc, xferOp.source(), indices);
11286825bfe2SNicolas Vasilache           return b.create<vector::InsertElementOp>(loc, val, vec, ivI32);
11290f241638SMatthias Springer         },
11300f241638SMatthias Springer         /*outOfBoundsCase=*/
11310f241638SMatthias Springer         [&](OpBuilder & /*b*/, Location loc) { return vec; });
11326825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc, nextVec);
11330f241638SMatthias Springer   }
11340f241638SMatthias Springer 
11356825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
11360f241638SMatthias Springer     // Initialize the vector with the padding value.
11376825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
11386825bfe2SNicolas Vasilache     return b.create<SplatOp>(loc, xferOp.getVectorType(), xferOp.padding());
11390f241638SMatthias Springer   }
11400f241638SMatthias Springer };
11410f241638SMatthias Springer 
11420f241638SMatthias Springer /// Codegen strategy for TransferWriteOp.
11430f241638SMatthias Springer template <>
11440f241638SMatthias Springer struct Strategy1d<TransferWriteOp> {
11456825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
11460f241638SMatthias Springer                                   TransferWriteOp xferOp, Value iv,
11470f241638SMatthias Springer                                   ValueRange /*loopState*/) {
11480f241638SMatthias Springer     SmallVector<Value, 8> indices;
11496825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
11506825bfe2SNicolas Vasilache     Value ivI32 =
11516825bfe2SNicolas Vasilache         b.create<IndexCastOp>(loc, IntegerType::get(b.getContext(), 32), iv);
11520f241638SMatthias Springer 
11530f241638SMatthias Springer     // Nothing to do in case of out-of-bounds access.
11540f241638SMatthias Springer     generateInBoundsCheck(
11556825bfe2SNicolas Vasilache         b, xferOp, iv, dim,
11566825bfe2SNicolas Vasilache         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
11576825bfe2SNicolas Vasilache           auto val =
11586825bfe2SNicolas Vasilache               b.create<vector::ExtractElementOp>(loc, xferOp.vector(), ivI32);
11596825bfe2SNicolas Vasilache           b.create<memref::StoreOp>(loc, val, xferOp.source(), indices);
11600f241638SMatthias Springer         });
11616825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc);
11620f241638SMatthias Springer   }
11630f241638SMatthias Springer 
11646825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
11656825bfe2SNicolas Vasilache     return Value();
11666825bfe2SNicolas Vasilache   }
11670f241638SMatthias Springer };
11680f241638SMatthias Springer 
11690f241638SMatthias Springer /// Return true if the last dimension of the MemRefType has unit stride.
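/// E.g. (hypothetical types): `memref<4x8xf32>` has unit stride in its last
/// dimension, whereas a memref with layout
/// `affine_map<(d0, d1) -> (d0 * 16 + d1 * 2)>` (last-dimension stride 2)
/// does not.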
11700f241638SMatthias Springer static bool isLastMemrefDimUnitStride(MemRefType type) {
11710f241638SMatthias Springer   int64_t offset;
11720f241638SMatthias Springer   SmallVector<int64_t, 4> strides;
11730f241638SMatthias Springer   auto successStrides = getStridesAndOffset(type, strides, offset);
11745017b0f8SMatthias Springer   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
11750f241638SMatthias Springer }
11760f241638SMatthias Springer 
11770f241638SMatthias Springer /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
11780f241638SMatthias Springer /// necessary in cases where a 1D vector transfer op cannot be lowered into
11790f241638SMatthias Springer /// vector load/stores due to non-unit strides or broadcasts:
11800f241638SMatthias Springer ///
11810f241638SMatthias Springer /// * Transfer dimension is not the last memref dimension
11820f241638SMatthias Springer /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
11830f241638SMatthias Springer /// * Memref has a layout map with non-unit stride on the last dimension
11840f241638SMatthias Springer ///
11850f241638SMatthias Springer /// This pattern generates IR as follows:
11860f241638SMatthias Springer ///
11870f241638SMatthias Springer /// 1. Generate a for loop iterating over each vector element.
11880f241638SMatthias Springer /// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
11890f241638SMatthias Springer ///    depending on OpTy.
11900f241638SMatthias Springer ///
11910f241638SMatthias Springer /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
11920f241638SMatthias Springer ///       can be generated instead of TransferOp1dConversion. Add such a pattern
11930f241638SMatthias Springer ///       to ConvertVectorToLLVM.
11940f241638SMatthias Springer ///
11950f241638SMatthias Springer /// E.g.:
11960f241638SMatthias Springer /// ```
11970f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b]
11980f241638SMatthias Springer ///    {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
11990f241638SMatthias Springer ///    : vector<9xf32>, memref<?x?xf32>
12000f241638SMatthias Springer /// ```
12010f241638SMatthias Springer /// is rewritten to approximately the following pseudo-IR:
12020f241638SMatthias Springer /// ```
12030f241638SMatthias Springer /// for i = 0 to 9 {
12040f241638SMatthias Springer ///   %t = vector.extractelement %vec[i] : vector<9xf32>
12050f241638SMatthias Springer ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
12060f241638SMatthias Springer /// }
12070f241638SMatthias Springer /// ```
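///
/// An analogous `vector.transfer_read` (same permutation map and bounds)
/// would be lowered to roughly the following pseudo-IR (a simplified sketch
/// with illustrative names):
/// ```
/// %v_init = splat %padding : vector<9xf32>
/// %vec = for i = 0 to 9 iter_args(%v = %v_init) {
///   %t = memref.load %A[%a + i, %b] : memref<?x?xf32>
///   %v_next = vector.insertelement %t, %v[i] : vector<9xf32>
///   yield %v_next
/// }
/// ```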
12080f241638SMatthias Springer template <typename OpTy>
12092ca887deSMatthias Springer struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
12102ca887deSMatthias Springer   using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
12110f241638SMatthias Springer 
12120f241638SMatthias Springer   LogicalResult matchAndRewrite(OpTy xferOp,
12130f241638SMatthias Springer                                 PatternRewriter &rewriter) const override {
12140f241638SMatthias Springer     auto map = xferOp.permutation_map();
12150f241638SMatthias Springer     auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
12160f241638SMatthias Springer 
12170f241638SMatthias Springer     if (!memRefType)
12180f241638SMatthias Springer       return failure();
12190f241638SMatthias Springer     if (xferOp.getVectorType().getRank() != 1)
12200f241638SMatthias Springer       return failure();
12210f241638SMatthias Springer     if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
12220f241638SMatthias Springer       return failure(); // Handled by ConvertVectorToLLVM
12230f241638SMatthias Springer 
12240f241638SMatthias Springer     // Loop bounds, step, state...
12256825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
12260f241638SMatthias Springer     auto vecType = xferOp.getVectorType();
12276825bfe2SNicolas Vasilache     auto lb = rewriter.create<ConstantIndexOp>(loc, 0);
12286825bfe2SNicolas Vasilache     auto ub = rewriter.create<ConstantIndexOp>(loc, vecType.getDimSize(0));
12296825bfe2SNicolas Vasilache     auto step = rewriter.create<ConstantIndexOp>(loc, 1);
12306825bfe2SNicolas Vasilache     auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
12310f241638SMatthias Springer 
12320f241638SMatthias Springer     // Generate for loop.
12330f241638SMatthias Springer     rewriter.replaceOpWithNewOp<scf::ForOp>(
12340f241638SMatthias Springer         xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
12356825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
12366825bfe2SNicolas Vasilache           Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
12370f241638SMatthias Springer         });
12380f241638SMatthias Springer 
12390f241638SMatthias Springer     return success();
12400f241638SMatthias Springer   }
12410f241638SMatthias Springer };
12424ead2cf7SAlex Zinenko 
1243a088bed4SMatthias Springer } // namespace lowering_1_d
1244df63eedeSBenjamin Kramer } // namespace
1245df63eedeSBenjamin Kramer 
124651d30c34SBenjamin Kramer namespace mlir {
124751d30c34SBenjamin Kramer 
12483393cc4cSNicolas Vasilache void populateVectorToSCFConversionPatterns(
1249dc4e913bSChris Lattner     RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
12500f241638SMatthias Springer   if (options.unroll) {
1251a088bed4SMatthias Springer     patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1252a088bed4SMatthias Springer                  lowering_n_d_unrolled::UnrollTransferWriteConversion>(
12532ca887deSMatthias Springer         patterns.getContext(), options);
12540f241638SMatthias Springer   } else {
1255a088bed4SMatthias Springer     patterns.add<lowering_n_d::PrepareTransferReadConversion,
1256a088bed4SMatthias Springer                  lowering_n_d::PrepareTransferWriteConversion,
1257a088bed4SMatthias Springer                  lowering_n_d::TransferOpConversion<TransferReadOp>,
1258a088bed4SMatthias Springer                  lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1259a088bed4SMatthias Springer         patterns.getContext(), options);
12600f241638SMatthias Springer   }
12610f241638SMatthias Springer 
12622ca887deSMatthias Springer   if (options.targetRank == 1) {
1263a088bed4SMatthias Springer     patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1264a088bed4SMatthias Springer                  lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1265a088bed4SMatthias Springer         patterns.getContext(), options);
12660f241638SMatthias Springer   }
12674ead2cf7SAlex Zinenko }
12683393cc4cSNicolas Vasilache 
12693393cc4cSNicolas Vasilache } // namespace mlir
12703393cc4cSNicolas Vasilache 
12715f9e0466SNicolas Vasilache namespace {
12725f9e0466SNicolas Vasilache 
12735f9e0466SNicolas Vasilache struct ConvertVectorToSCFPass
12745f9e0466SNicolas Vasilache     : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
12755f9e0466SNicolas Vasilache   ConvertVectorToSCFPass() = default;
12765f9e0466SNicolas Vasilache   ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
12775f9e0466SNicolas Vasilache     this->fullUnroll = options.unroll;
12782ca887deSMatthias Springer     this->targetRank = options.targetRank;
1279fb7ec1f1SMatthias Springer     this->lowerPermutationMaps = options.lowerPermutationMaps;
1280558e7401SMatthias Springer     this->lowerTensors = options.lowerTensors;
12815f9e0466SNicolas Vasilache   }
12825f9e0466SNicolas Vasilache 
12835f9e0466SNicolas Vasilache   void runOnFunction() override {
12842ca887deSMatthias Springer     VectorTransferToSCFOptions options;
1285fb7ec1f1SMatthias Springer     options.unroll = fullUnroll;
1286fb7ec1f1SMatthias Springer     options.targetRank = targetRank;
1287fb7ec1f1SMatthias Springer     options.lowerPermutationMaps = lowerPermutationMaps;
1288558e7401SMatthias Springer     options.lowerTensors = lowerTensors;
1289fb7ec1f1SMatthias Springer 
1290fb7ec1f1SMatthias Springer     // Lower permutation maps first.
1291fb7ec1f1SMatthias Springer     if (lowerPermutationMaps) {
1292fb7ec1f1SMatthias Springer       RewritePatternSet lowerTransferPatterns(getFunction().getContext());
1293fb7ec1f1SMatthias Springer       mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1294fb7ec1f1SMatthias Springer           lowerTransferPatterns);
1295fb7ec1f1SMatthias Springer       (void)applyPatternsAndFoldGreedily(getFunction(),
1296fb7ec1f1SMatthias Springer                                          std::move(lowerTransferPatterns));
1297fb7ec1f1SMatthias Springer     }
12982ca887deSMatthias Springer 
1299dc4e913bSChris Lattner     RewritePatternSet patterns(getFunction().getContext());
13002ca887deSMatthias Springer     populateVectorToSCFConversionPatterns(patterns, options);
1301e21adfa3SRiver Riddle     (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
13025f9e0466SNicolas Vasilache   }
13035f9e0466SNicolas Vasilache };
13045f9e0466SNicolas Vasilache 
13055f9e0466SNicolas Vasilache } // namespace
13065f9e0466SNicolas Vasilache 
13075f9e0466SNicolas Vasilache std::unique_ptr<Pass>
13085f9e0466SNicolas Vasilache mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
13095f9e0466SNicolas Vasilache   return std::make_unique<ConvertVectorToSCFPass>(options);
13105f9e0466SNicolas Vasilache }
1311