10f241638SMatthias Springer //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
24ead2cf7SAlex Zinenko //
34ead2cf7SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44ead2cf7SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
54ead2cf7SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64ead2cf7SAlex Zinenko //
74ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===//
84ead2cf7SAlex Zinenko //
90f241638SMatthias Springer // This file implements lowering of vector transfer operations to SCF.
104ead2cf7SAlex Zinenko //
114ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===//
124ead2cf7SAlex Zinenko
134ead2cf7SAlex Zinenko #include <type_traits>
144ead2cf7SAlex Zinenko
154ead2cf7SAlex Zinenko #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
165f9e0466SNicolas Vasilache
175f9e0466SNicolas Vasilache #include "../PassDetail.h"
186825bfe2SNicolas Vasilache #include "mlir/Dialect/Affine/IR/AffineOps.h"
19a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
2066f878ceSMatthias Springer #include "mlir/Dialect/MemRef/IR/MemRef.h"
218b68da2cSAlex Zinenko #include "mlir/Dialect/SCF/IR/SCF.h"
2299ef9eebSMatthias Springer #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
234ead2cf7SAlex Zinenko #include "mlir/IR/Builders.h"
246825bfe2SNicolas Vasilache #include "mlir/IR/ImplicitLocOpBuilder.h"
255f9e0466SNicolas Vasilache #include "mlir/Pass/Pass.h"
26b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
275f9e0466SNicolas Vasilache #include "mlir/Transforms/Passes.h"
284ead2cf7SAlex Zinenko
294ead2cf7SAlex Zinenko using namespace mlir;
304ead2cf7SAlex Zinenko using vector::TransferReadOp;
314ead2cf7SAlex Zinenko using vector::TransferWriteOp;
324ead2cf7SAlex Zinenko
33350dadaaSBenjamin Kramer namespace {
340f241638SMatthias Springer
/// Name of the unit attribute attached to transfer ops that are still being
/// progressively lowered by the patterns in this file.
static constexpr char kPassLabel[] = "__vector_to_scf_lowering__";
370f241638SMatthias Springer
382ca887deSMatthias Springer /// Patterns that inherit from this struct have access to
392ca887deSMatthias Springer /// VectorTransferToSCFOptions.
402ca887deSMatthias Springer template <typename OpTy>
412ca887deSMatthias Springer struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
VectorToSCFPattern__anon4d9edda10111::VectorToSCFPattern422ca887deSMatthias Springer explicit VectorToSCFPattern(MLIRContext *context,
432ca887deSMatthias Springer VectorTransferToSCFOptions opt)
442ca887deSMatthias Springer : OpRewritePattern<OpTy>(context), options(opt) {}
452ca887deSMatthias Springer
462ca887deSMatthias Springer VectorTransferToSCFOptions options;
472ca887deSMatthias Springer };
480f241638SMatthias Springer
490f241638SMatthias Springer /// Given a vector transfer op, calculate which dimension of the `source`
500f241638SMatthias Springer /// memref should be unpacked in the next application of TransferOpConversion.
510f241638SMatthias Springer /// A return value of None indicates a broadcast.
520f241638SMatthias Springer template <typename OpTy>
unpackedDim(OpTy xferOp)530f241638SMatthias Springer static Optional<int64_t> unpackedDim(OpTy xferOp) {
54c537a943SNicolas Vasilache // TODO: support 0-d corner case.
55c537a943SNicolas Vasilache assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
567c38fd60SJacques Pienaar auto map = xferOp.getPermutationMap();
570f241638SMatthias Springer if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
580f241638SMatthias Springer return expr.getPosition();
597c3c5b11SNicolas Vasilache }
600f241638SMatthias Springer assert(xferOp.isBroadcastDim(0) &&
610f241638SMatthias Springer "Expected AffineDimExpr or AffineConstantExpr");
620f241638SMatthias Springer return None;
630f241638SMatthias Springer }
640f241638SMatthias Springer
650f241638SMatthias Springer /// Compute the permutation map for the new (N-1)-D vector transfer op. This
660f241638SMatthias Springer /// map is identical to the current permutation map, but the first result is
670f241638SMatthias Springer /// omitted.
680f241638SMatthias Springer template <typename OpTy>
unpackedPermutationMap(OpBuilder & b,OpTy xferOp)696825bfe2SNicolas Vasilache static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
70c537a943SNicolas Vasilache // TODO: support 0-d corner case.
71c537a943SNicolas Vasilache assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
727c38fd60SJacques Pienaar auto map = xferOp.getPermutationMap();
730f241638SMatthias Springer return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
746825bfe2SNicolas Vasilache b.getContext());
750f241638SMatthias Springer }
760f241638SMatthias Springer
770f241638SMatthias Springer /// Calculate the indices for the new vector transfer op.
780f241638SMatthias Springer ///
790f241638SMatthias Springer /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
800f241638SMatthias Springer /// --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32>
810f241638SMatthias Springer /// ^^^^^^
820f241638SMatthias Springer /// `iv` is the iteration variable of the (new) surrounding loop.
830f241638SMatthias Springer template <typename OpTy>
getXferIndices(OpBuilder & b,OpTy xferOp,Value iv,SmallVector<Value,8> & indices)846825bfe2SNicolas Vasilache static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
850f241638SMatthias Springer SmallVector<Value, 8> &indices) {
860f241638SMatthias Springer typename OpTy::Adaptor adaptor(xferOp);
870f241638SMatthias Springer // Corresponding memref dim of the vector dim that is unpacked.
880f241638SMatthias Springer auto dim = unpackedDim(xferOp);
897c38fd60SJacques Pienaar auto prevIndices = adaptor.getIndices();
900f241638SMatthias Springer indices.append(prevIndices.begin(), prevIndices.end());
910f241638SMatthias Springer
926825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
93491d2701SKazu Hirata bool isBroadcast = !dim.has_value();
940f241638SMatthias Springer if (!isBroadcast) {
956825bfe2SNicolas Vasilache AffineExpr d0, d1;
966825bfe2SNicolas Vasilache bindDims(xferOp.getContext(), d0, d1);
97c27d8152SKazu Hirata Value offset = adaptor.getIndices()[dim.value()];
98c27d8152SKazu Hirata indices[dim.value()] =
996825bfe2SNicolas Vasilache makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
1000f241638SMatthias Springer }
1010f241638SMatthias Springer }
1020f241638SMatthias Springer
maybeYieldValue(OpBuilder & b,Location loc,bool hasRetVal,Value value)1036825bfe2SNicolas Vasilache static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
1040f241638SMatthias Springer Value value) {
1050f241638SMatthias Springer if (hasRetVal) {
106558e7401SMatthias Springer assert(value && "Expected non-empty value");
1076825bfe2SNicolas Vasilache b.create<scf::YieldOp>(loc, value);
1080f241638SMatthias Springer } else {
1096825bfe2SNicolas Vasilache b.create<scf::YieldOp>(loc);
1100f241638SMatthias Springer }
1110f241638SMatthias Springer }
1120f241638SMatthias Springer
1130f241638SMatthias Springer /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
1140f241638SMatthias Springer /// is set to true. No such check is generated under following circumstances:
1150f241638SMatthias Springer /// * xferOp does not have a mask.
1160f241638SMatthias Springer /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
1170f241638SMatthias Springer /// computed and attached to the new transfer op in the pattern.)
1180f241638SMatthias Springer /// * The to-be-unpacked dim of xferOp is a broadcast.
1190f241638SMatthias Springer template <typename OpTy>
generateMaskCheck(OpBuilder & b,OpTy xferOp,Value iv)1206825bfe2SNicolas Vasilache static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
1217c38fd60SJacques Pienaar if (!xferOp.getMask())
1220f241638SMatthias Springer return Value();
1230f241638SMatthias Springer if (xferOp.getMaskType().getRank() != 1)
1240f241638SMatthias Springer return Value();
1250f241638SMatthias Springer if (xferOp.isBroadcastDim(0))
1260f241638SMatthias Springer return Value();
1270f241638SMatthias Springer
1286825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
1297c38fd60SJacques Pienaar return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
1300f241638SMatthias Springer }
1310f241638SMatthias Springer
1320f241638SMatthias Springer /// Helper function TransferOpConversion and TransferOp1dConversion.
1330f241638SMatthias Springer /// Generate an in-bounds check if the transfer op may go out-of-bounds on the
1340f241638SMatthias Springer /// specified dimension `dim` with the loop iteration variable `iv`.
1350f241638SMatthias Springer /// E.g., when unpacking dimension 0 from:
1360f241638SMatthias Springer /// ```
1370f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b] %cst
1380f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?xf32>
1390f241638SMatthias Springer /// ```
1400f241638SMatthias Springer /// An if check similar to this will be generated inside the loop:
1410f241638SMatthias Springer /// ```
1420f241638SMatthias Springer /// %d = memref.dim %A, %c0 : memref<?x?xf32>
1430f241638SMatthias Springer /// if (%a + iv < %d) {
1440f241638SMatthias Springer /// (in-bounds case)
1450f241638SMatthias Springer /// } else {
1460f241638SMatthias Springer /// (out-of-bounds case)
1470f241638SMatthias Springer /// }
1480f241638SMatthias Springer /// ```
1490f241638SMatthias Springer ///
1500f241638SMatthias Springer /// If the transfer is 1D and has a mask, this function generates a more complex
1510f241638SMatthias Springer /// check also accounts for potentially masked out elements.
1520f241638SMatthias Springer ///
1530f241638SMatthias Springer /// This function variant returns the value returned by `inBoundsCase` or
1540f241638SMatthias Springer /// `outOfBoundsCase`. The MLIR type of the return value must be specified in
1550f241638SMatthias Springer /// `resultTypes`.
1560f241638SMatthias Springer template <typename OpTy>
generateInBoundsCheck(OpBuilder & b,OpTy xferOp,Value iv,Optional<int64_t> dim,TypeRange resultTypes,function_ref<Value (OpBuilder &,Location)> inBoundsCase,function_ref<Value (OpBuilder &,Location)> outOfBoundsCase=nullptr)1570f241638SMatthias Springer static Value generateInBoundsCheck(
1586825bfe2SNicolas Vasilache OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
1590f241638SMatthias Springer TypeRange resultTypes,
1600f241638SMatthias Springer function_ref<Value(OpBuilder &, Location)> inBoundsCase,
1610f241638SMatthias Springer function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
1620f241638SMatthias Springer bool hasRetVal = !resultTypes.empty();
1630f241638SMatthias Springer Value cond; // Condition to be built...
1640f241638SMatthias Springer
1650f241638SMatthias Springer // Condition check 1: Access in-bounds?
1660916d96dSKazu Hirata bool isBroadcast = !dim; // No in-bounds check for broadcasts.
1676825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
1686825bfe2SNicolas Vasilache ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
1690f241638SMatthias Springer if (!xferOp.isDimInBounds(0) && !isBroadcast) {
1707c38fd60SJacques Pienaar Value memrefDim =
1717c38fd60SJacques Pienaar vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
1726825bfe2SNicolas Vasilache AffineExpr d0, d1;
1736825bfe2SNicolas Vasilache bindDims(xferOp.getContext(), d0, d1);
1746d5fc1e3SKazu Hirata Value base = xferOp.getIndices()[*dim];
1756825bfe2SNicolas Vasilache Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
176a54f4eaeSMogball cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
177a54f4eaeSMogball memrefIdx);
1780f241638SMatthias Springer }
1790f241638SMatthias Springer
1800f241638SMatthias Springer // Condition check 2: Masked in?
1816825bfe2SNicolas Vasilache if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
1826825bfe2SNicolas Vasilache if (cond)
183a54f4eaeSMogball cond = lb.create<arith::AndIOp>(cond, maskCond);
1846825bfe2SNicolas Vasilache else
1850f241638SMatthias Springer cond = maskCond;
1860f241638SMatthias Springer }
1870f241638SMatthias Springer
1880f241638SMatthias Springer // If the condition is non-empty, generate an SCF::IfOp.
1890f241638SMatthias Springer if (cond) {
1906825bfe2SNicolas Vasilache auto check = lb.create<scf::IfOp>(
1916825bfe2SNicolas Vasilache resultTypes, cond,
1920f241638SMatthias Springer /*thenBuilder=*/
1936825bfe2SNicolas Vasilache [&](OpBuilder &b, Location loc) {
1946825bfe2SNicolas Vasilache maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
195cadb7ccfSAlex Zinenko },
1960f241638SMatthias Springer /*elseBuilder=*/
1976825bfe2SNicolas Vasilache [&](OpBuilder &b, Location loc) {
1980f241638SMatthias Springer if (outOfBoundsCase) {
1996825bfe2SNicolas Vasilache maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
2007c3c5b11SNicolas Vasilache } else {
2016825bfe2SNicolas Vasilache b.create<scf::YieldOp>(loc);
2027c3c5b11SNicolas Vasilache }
2037c3c5b11SNicolas Vasilache });
2047c3c5b11SNicolas Vasilache
2050f241638SMatthias Springer return hasRetVal ? check.getResult(0) : Value();
2064ead2cf7SAlex Zinenko }
2074ead2cf7SAlex Zinenko
2080f241638SMatthias Springer // Condition is empty, no need for an SCF::IfOp.
2096825bfe2SNicolas Vasilache return inBoundsCase(b, loc);
2100f241638SMatthias Springer }
2110f241638SMatthias Springer
2120f241638SMatthias Springer /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
2130f241638SMatthias Springer /// a return value. Consequently, this function does not have a return value.
2140f241638SMatthias Springer template <typename OpTy>
generateInBoundsCheck(OpBuilder & b,OpTy xferOp,Value iv,Optional<int64_t> dim,function_ref<void (OpBuilder &,Location)> inBoundsCase,function_ref<void (OpBuilder &,Location)> outOfBoundsCase=nullptr)2150f241638SMatthias Springer static void generateInBoundsCheck(
2166825bfe2SNicolas Vasilache OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
2170f241638SMatthias Springer function_ref<void(OpBuilder &, Location)> inBoundsCase,
2180f241638SMatthias Springer function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
2190f241638SMatthias Springer generateInBoundsCheck(
2206825bfe2SNicolas Vasilache b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
2210f241638SMatthias Springer /*inBoundsCase=*/
2226825bfe2SNicolas Vasilache [&](OpBuilder &b, Location loc) {
2236825bfe2SNicolas Vasilache inBoundsCase(b, loc);
2240f241638SMatthias Springer return Value();
2250f241638SMatthias Springer },
2260f241638SMatthias Springer /*outOfBoundsCase=*/
2276825bfe2SNicolas Vasilache [&](OpBuilder &b, Location loc) {
2280f241638SMatthias Springer if (outOfBoundsCase)
2296825bfe2SNicolas Vasilache outOfBoundsCase(b, loc);
2300f241638SMatthias Springer return Value();
2310f241638SMatthias Springer });
2320f241638SMatthias Springer }
2330f241638SMatthias Springer
2340f241638SMatthias Springer /// Given an ArrayAttr, return a copy where the first element is dropped.
dropFirstElem(OpBuilder & b,ArrayAttr attr)2356825bfe2SNicolas Vasilache static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
2360f241638SMatthias Springer if (!attr)
2370f241638SMatthias Springer return attr;
2386825bfe2SNicolas Vasilache return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
2390f241638SMatthias Springer }
2400f241638SMatthias Springer
2410f241638SMatthias Springer /// Add the pass label to a vector transfer op if its rank is not the target
2420f241638SMatthias Springer /// rank.
2430f241638SMatthias Springer template <typename OpTy>
maybeApplyPassLabel(OpBuilder & b,OpTy newXferOp,unsigned targetRank)2446825bfe2SNicolas Vasilache static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
2452ca887deSMatthias Springer unsigned targetRank) {
2462ca887deSMatthias Springer if (newXferOp.getVectorType().getRank() > targetRank)
2476825bfe2SNicolas Vasilache newXferOp->setAttr(kPassLabel, b.getUnitAttr());
2480f241638SMatthias Springer }
2490f241638SMatthias Springer
250558e7401SMatthias Springer /// Return true if this transfer op operates on a source tensor.
251558e7401SMatthias Springer template <typename OpTy>
isTensorOp(OpTy xferOp)252558e7401SMatthias Springer static bool isTensorOp(OpTy xferOp) {
253558e7401SMatthias Springer if (xferOp.getShapedType().template isa<RankedTensorType>()) {
254558e7401SMatthias Springer if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
255558e7401SMatthias Springer // TransferWriteOps on tensors have a result.
256558e7401SMatthias Springer assert(xferOp->getNumResults() > 0);
257558e7401SMatthias Springer }
258558e7401SMatthias Springer return true;
259558e7401SMatthias Springer }
260558e7401SMatthias Springer return false;
261558e7401SMatthias Springer }
262558e7401SMatthias Springer
263a088bed4SMatthias Springer namespace lowering_n_d {
264a088bed4SMatthias Springer
265a088bed4SMatthias Springer /// Helper data structure for data and mask buffers.
266a088bed4SMatthias Springer struct BufferAllocs {
267a088bed4SMatthias Springer Value dataBuffer;
268a088bed4SMatthias Springer Value maskBuffer;
269a088bed4SMatthias Springer };
270a088bed4SMatthias Springer
2713c3810e7SNicolas Vasilache // TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
getAutomaticAllocationScope(Operation * op)2723c3810e7SNicolas Vasilache static Operation *getAutomaticAllocationScope(Operation *op) {
2733c3810e7SNicolas Vasilache Operation *scope =
2743c3810e7SNicolas Vasilache op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
2753c3810e7SNicolas Vasilache assert(scope && "Expected op to be inside automatic allocation scope");
2763c3810e7SNicolas Vasilache return scope;
2773c3810e7SNicolas Vasilache }
2783c3810e7SNicolas Vasilache
279a088bed4SMatthias Springer /// Allocate temporary buffers for data (vector) and mask (if present).
280a088bed4SMatthias Springer template <typename OpTy>
allocBuffers(OpBuilder & b,OpTy xferOp)2816825bfe2SNicolas Vasilache static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
2826825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
283a088bed4SMatthias Springer OpBuilder::InsertionGuard guard(b);
2843c3810e7SNicolas Vasilache Operation *scope = getAutomaticAllocationScope(xferOp);
2853c3810e7SNicolas Vasilache assert(scope->getNumRegions() == 1 &&
2863c3810e7SNicolas Vasilache "AutomaticAllocationScope with >1 regions");
287a088bed4SMatthias Springer b.setInsertionPointToStart(&scope->getRegion(0).front());
288a088bed4SMatthias Springer
289a088bed4SMatthias Springer BufferAllocs result;
290a088bed4SMatthias Springer auto bufferType = MemRefType::get({}, xferOp.getVectorType());
2916825bfe2SNicolas Vasilache result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);
292a088bed4SMatthias Springer
2937c38fd60SJacques Pienaar if (xferOp.getMask()) {
2947c38fd60SJacques Pienaar auto maskType = MemRefType::get({}, xferOp.getMask().getType());
2956825bfe2SNicolas Vasilache auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
296fb7ec1f1SMatthias Springer b.setInsertionPoint(xferOp);
2977c38fd60SJacques Pienaar b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
2986825bfe2SNicolas Vasilache result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
299a088bed4SMatthias Springer }
300a088bed4SMatthias Springer
301a088bed4SMatthias Springer return result;
302a088bed4SMatthias Springer }
303a088bed4SMatthias Springer
304a088bed4SMatthias Springer /// Given a MemRefType with VectorType element type, unpack one dimension from
305a088bed4SMatthias Springer /// the VectorType into the MemRefType.
306a088bed4SMatthias Springer ///
307a088bed4SMatthias Springer /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
unpackOneDim(MemRefType type)308a088bed4SMatthias Springer static MemRefType unpackOneDim(MemRefType type) {
309a088bed4SMatthias Springer auto vectorType = type.getElementType().dyn_cast<VectorType>();
310a088bed4SMatthias Springer auto memrefShape = type.getShape();
311a088bed4SMatthias Springer SmallVector<int64_t, 8> newMemrefShape;
312a088bed4SMatthias Springer newMemrefShape.append(memrefShape.begin(), memrefShape.end());
313a088bed4SMatthias Springer newMemrefShape.push_back(vectorType.getDimSize(0));
314a088bed4SMatthias Springer return MemRefType::get(newMemrefShape,
315a088bed4SMatthias Springer VectorType::get(vectorType.getShape().drop_front(),
316a088bed4SMatthias Springer vectorType.getElementType()));
317a088bed4SMatthias Springer }
318a088bed4SMatthias Springer
3190f241638SMatthias Springer /// Given a transfer op, find the memref from which the mask is loaded. This
3200f241638SMatthias Springer /// is similar to Strategy<TransferWriteOp>::getBuffer.
3210f241638SMatthias Springer template <typename OpTy>
getMaskBuffer(OpTy xferOp)3220f241638SMatthias Springer static Value getMaskBuffer(OpTy xferOp) {
3237c38fd60SJacques Pienaar assert(xferOp.getMask() && "Expected that transfer op has mask");
3247c38fd60SJacques Pienaar auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
3250f241638SMatthias Springer assert(loadOp && "Expected transfer op mask produced by LoadOp");
3260f241638SMatthias Springer return loadOp.getMemRef();
3270f241638SMatthias Springer }
3280f241638SMatthias Springer
3290f241638SMatthias Springer /// Codegen strategy, depending on the operation.
3300f241638SMatthias Springer template <typename OpTy>
3310f241638SMatthias Springer struct Strategy;
3320f241638SMatthias Springer
3330f241638SMatthias Springer /// Code strategy for vector TransferReadOp.
3344ead2cf7SAlex Zinenko template <>
3350f241638SMatthias Springer struct Strategy<TransferReadOp> {
3360f241638SMatthias Springer /// Find the StoreOp that is used for writing the current TransferReadOp's
3370f241638SMatthias Springer /// result to the temporary buffer allocation.
getStoreOp__anon4d9edda10111::lowering_n_d::Strategy3380f241638SMatthias Springer static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
3390f241638SMatthias Springer assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
3400f241638SMatthias Springer auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
3410f241638SMatthias Springer assert(storeOp && "Expected TransferReadOp result used by StoreOp");
3420f241638SMatthias Springer return storeOp;
3437c3c5b11SNicolas Vasilache }
3444ead2cf7SAlex Zinenko
3450f241638SMatthias Springer /// Find the temporary buffer allocation. All labeled TransferReadOps are
3460f241638SMatthias Springer /// used like this, where %buf is either the buffer allocation or a type cast
3470f241638SMatthias Springer /// of the buffer allocation:
3480f241638SMatthias Springer /// ```
3490f241638SMatthias Springer /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
3500f241638SMatthias Springer /// memref.store %vec, %buf[...] ...
3510f241638SMatthias Springer /// ```
getBuffer__anon4d9edda10111::lowering_n_d::Strategy3520f241638SMatthias Springer static Value getBuffer(TransferReadOp xferOp) {
3530f241638SMatthias Springer return getStoreOp(xferOp).getMemRef();
3541870e787SNicolas Vasilache }
3550f241638SMatthias Springer
3560f241638SMatthias Springer /// Retrieve the indices of the current StoreOp that stores into the buffer.
getBufferIndices__anon4d9edda10111::lowering_n_d::Strategy3570f241638SMatthias Springer static void getBufferIndices(TransferReadOp xferOp,
3580f241638SMatthias Springer SmallVector<Value, 8> &indices) {
3590f241638SMatthias Springer auto storeOp = getStoreOp(xferOp);
360136d746eSJacques Pienaar auto prevIndices = memref::StoreOpAdaptor(storeOp).getIndices();
3610f241638SMatthias Springer indices.append(prevIndices.begin(), prevIndices.end());
3620f241638SMatthias Springer }
3630f241638SMatthias Springer
3640f241638SMatthias Springer /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
3650f241638SMatthias Springer /// accesses on the to-be-unpacked dimension.
3660f241638SMatthias Springer ///
3670f241638SMatthias Springer /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
3680f241638SMatthias Springer /// variable `iv`.
3690f241638SMatthias Springer /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
3700f241638SMatthias Springer ///
3710f241638SMatthias Springer /// E.g.:
3720f241638SMatthias Springer /// ```
3730f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
3740f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<4x3xf32>
3750f241638SMatthias Springer /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
3760f241638SMatthias Springer /// ```
3770f241638SMatthias Springer /// Is rewritten to:
3780f241638SMatthias Springer /// ```
3790f241638SMatthias Springer /// %casted = vector.type_cast %buf
3800f241638SMatthias Springer /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
3810f241638SMatthias Springer /// for %j = 0 to 4 {
3820f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
3830f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<3xf32>
3840f241638SMatthias Springer /// memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
3850f241638SMatthias Springer /// }
3860f241638SMatthias Springer /// ```
3870f241638SMatthias Springer ///
3880f241638SMatthias Springer /// Note: The loop and type cast are generated in TransferOpConversion.
3890f241638SMatthias Springer /// The original TransferReadOp and store op are deleted in `cleanup`.
3900f241638SMatthias Springer /// Note: The `mask` operand is set in TransferOpConversion.
rewriteOp__anon4d9edda10111::lowering_n_d::Strategy3916825bfe2SNicolas Vasilache static TransferReadOp rewriteOp(OpBuilder &b,
3922ca887deSMatthias Springer VectorTransferToSCFOptions options,
393558e7401SMatthias Springer TransferReadOp xferOp, Value buffer, Value iv,
394558e7401SMatthias Springer ValueRange /*loopState*/) {
3950f241638SMatthias Springer SmallVector<Value, 8> storeIndices;
3960f241638SMatthias Springer getBufferIndices(xferOp, storeIndices);
3970f241638SMatthias Springer storeIndices.push_back(iv);
3980f241638SMatthias Springer
3990f241638SMatthias Springer SmallVector<Value, 8> xferIndices;
4006825bfe2SNicolas Vasilache getXferIndices(b, xferOp, iv, xferIndices);
4010f241638SMatthias Springer
4026825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
4030f241638SMatthias Springer auto bufferType = buffer.getType().dyn_cast<ShapedType>();
4040f241638SMatthias Springer auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
4057c38fd60SJacques Pienaar auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
4066825bfe2SNicolas Vasilache auto newXferOp = b.create<vector::TransferReadOp>(
4077c38fd60SJacques Pienaar loc, vecType, xferOp.getSource(), xferIndices,
4087c38fd60SJacques Pienaar AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
4097c38fd60SJacques Pienaar xferOp.getPadding(), Value(), inBoundsAttr);
4100f241638SMatthias Springer
4116825bfe2SNicolas Vasilache maybeApplyPassLabel(b, newXferOp, options.targetRank);
4120f241638SMatthias Springer
4137c38fd60SJacques Pienaar b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
4146825bfe2SNicolas Vasilache return newXferOp;
4150f241638SMatthias Springer }
4160f241638SMatthias Springer
4170f241638SMatthias Springer /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
4180f241638SMatthias Springer /// padding value to the temporary buffer.
handleOutOfBoundsDim__anon4d9edda10111::lowering_n_d::Strategy419558e7401SMatthias Springer static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
420558e7401SMatthias Springer Value buffer, Value iv,
421558e7401SMatthias Springer ValueRange /*loopState*/) {
4220f241638SMatthias Springer SmallVector<Value, 8> storeIndices;
4230f241638SMatthias Springer getBufferIndices(xferOp, storeIndices);
4240f241638SMatthias Springer storeIndices.push_back(iv);
4250f241638SMatthias Springer
4266825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
4270f241638SMatthias Springer auto bufferType = buffer.getType().dyn_cast<ShapedType>();
4280f241638SMatthias Springer auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
4297c38fd60SJacques Pienaar auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
4306825bfe2SNicolas Vasilache b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);
431558e7401SMatthias Springer
432558e7401SMatthias Springer return Value();
4330f241638SMatthias Springer }
4340f241638SMatthias Springer
4350f241638SMatthias Springer /// Cleanup after rewriting the op.
cleanup__anon4d9edda10111::lowering_n_d::Strategy436558e7401SMatthias Springer static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
437558e7401SMatthias Springer scf::ForOp /*forOp*/) {
4380f241638SMatthias Springer rewriter.eraseOp(getStoreOp(xferOp));
4390f241638SMatthias Springer rewriter.eraseOp(xferOp);
4400f241638SMatthias Springer }
441558e7401SMatthias Springer
442558e7401SMatthias Springer /// Return the initial loop state for the generated scf.for loop.
initialLoopState__anon4d9edda10111::lowering_n_d::Strategy443558e7401SMatthias Springer static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
4444ead2cf7SAlex Zinenko };
4457c3c5b11SNicolas Vasilache
4460f241638SMatthias Springer /// Codegen strategy for vector TransferWriteOp.
4470f241638SMatthias Springer template <>
4480f241638SMatthias Springer struct Strategy<TransferWriteOp> {
4490f241638SMatthias Springer /// Find the temporary buffer allocation. All labeled TransferWriteOps are
4500f241638SMatthias Springer /// used like this, where %buf is either the buffer allocation or a type cast
4510f241638SMatthias Springer /// of the buffer allocation:
4520f241638SMatthias Springer /// ```
4530f241638SMatthias Springer /// %vec = memref.load %buf[...] ...
4540f241638SMatthias Springer /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
4550f241638SMatthias Springer /// ```
getBuffer__anon4d9edda10111::lowering_n_d::Strategy4560f241638SMatthias Springer static Value getBuffer(TransferWriteOp xferOp) {
4577c38fd60SJacques Pienaar auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
4580f241638SMatthias Springer assert(loadOp && "Expected transfer op vector produced by LoadOp");
4590f241638SMatthias Springer return loadOp.getMemRef();
4607c3c5b11SNicolas Vasilache }
4614ead2cf7SAlex Zinenko
4620f241638SMatthias Springer /// Retrieve the indices of the current LoadOp that loads from the buffer.
getBufferIndices__anon4d9edda10111::lowering_n_d::Strategy4630f241638SMatthias Springer static void getBufferIndices(TransferWriteOp xferOp,
4640f241638SMatthias Springer SmallVector<Value, 8> &indices) {
4657c38fd60SJacques Pienaar auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
466136d746eSJacques Pienaar auto prevIndices = memref::LoadOpAdaptor(loadOp).getIndices();
4670f241638SMatthias Springer indices.append(prevIndices.begin(), prevIndices.end());
4680f241638SMatthias Springer }
4690f241638SMatthias Springer
4700f241638SMatthias Springer /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
4710f241638SMatthias Springer /// accesses on the to-be-unpacked dimension.
4720f241638SMatthias Springer ///
4730f241638SMatthias Springer /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
4740f241638SMatthias Springer /// using the loop iteration variable `iv`.
4750f241638SMatthias Springer /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
4760f241638SMatthias Springer /// to memory.
4770f241638SMatthias Springer ///
4780f241638SMatthias Springer /// Note: For more details, see comments on Strategy<TransferReadOp>.
rewriteOp__anon4d9edda10111::lowering_n_d::Strategy4796825bfe2SNicolas Vasilache static TransferWriteOp rewriteOp(OpBuilder &b,
4802ca887deSMatthias Springer VectorTransferToSCFOptions options,
4812ca887deSMatthias Springer TransferWriteOp xferOp, Value buffer,
482558e7401SMatthias Springer Value iv, ValueRange loopState) {
4830f241638SMatthias Springer SmallVector<Value, 8> loadIndices;
4840f241638SMatthias Springer getBufferIndices(xferOp, loadIndices);
4850f241638SMatthias Springer loadIndices.push_back(iv);
4860f241638SMatthias Springer
4870f241638SMatthias Springer SmallVector<Value, 8> xferIndices;
4886825bfe2SNicolas Vasilache getXferIndices(b, xferOp, iv, xferIndices);
4890f241638SMatthias Springer
4906825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
4916825bfe2SNicolas Vasilache auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
4927c38fd60SJacques Pienaar auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
4937c38fd60SJacques Pienaar auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
494558e7401SMatthias Springer Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
4956825bfe2SNicolas Vasilache auto newXferOp = b.create<vector::TransferWriteOp>(
496558e7401SMatthias Springer loc, type, vec, source, xferIndices,
4976825bfe2SNicolas Vasilache AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
4980f241638SMatthias Springer inBoundsAttr);
4990f241638SMatthias Springer
5006825bfe2SNicolas Vasilache maybeApplyPassLabel(b, newXferOp, options.targetRank);
5010f241638SMatthias Springer
5026825bfe2SNicolas Vasilache return newXferOp;
5030f241638SMatthias Springer }
5040f241638SMatthias Springer
5050f241638SMatthias Springer /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
handleOutOfBoundsDim__anon4d9edda10111::lowering_n_d::Strategy506558e7401SMatthias Springer static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
507558e7401SMatthias Springer Value buffer, Value iv,
508558e7401SMatthias Springer ValueRange loopState) {
509558e7401SMatthias Springer return isTensorOp(xferOp) ? loopState[0] : Value();
510558e7401SMatthias Springer }
5110f241638SMatthias Springer
5120f241638SMatthias Springer /// Cleanup after rewriting the op.
cleanup__anon4d9edda10111::lowering_n_d::Strategy513558e7401SMatthias Springer static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
514558e7401SMatthias Springer scf::ForOp forOp) {
515558e7401SMatthias Springer if (isTensorOp(xferOp)) {
516558e7401SMatthias Springer assert(forOp->getNumResults() == 1 && "Expected one for loop result");
517558e7401SMatthias Springer rewriter.replaceOp(xferOp, forOp->getResult(0));
518558e7401SMatthias Springer } else {
5190f241638SMatthias Springer rewriter.eraseOp(xferOp);
5200f241638SMatthias Springer }
521558e7401SMatthias Springer }
522558e7401SMatthias Springer
523558e7401SMatthias Springer /// Return the initial loop state for the generated scf.for loop.
initialLoopState__anon4d9edda10111::lowering_n_d::Strategy524558e7401SMatthias Springer static Value initialLoopState(TransferWriteOp xferOp) {
5257c38fd60SJacques Pienaar return isTensorOp(xferOp) ? xferOp.getSource() : Value();
526558e7401SMatthias Springer }
5270f241638SMatthias Springer };
5280f241638SMatthias Springer
5290f241638SMatthias Springer template <typename OpTy>
checkPrepareXferOp(OpTy xferOp,VectorTransferToSCFOptions options)530fb7ec1f1SMatthias Springer LogicalResult checkPrepareXferOp(OpTy xferOp,
531fb7ec1f1SMatthias Springer VectorTransferToSCFOptions options) {
5320f241638SMatthias Springer if (xferOp->hasAttr(kPassLabel))
5330f241638SMatthias Springer return failure();
534fb7ec1f1SMatthias Springer if (xferOp.getVectorType().getRank() <= options.targetRank)
5350f241638SMatthias Springer return failure();
536558e7401SMatthias Springer if (isTensorOp(xferOp) && !options.lowerTensors)
5378fb48979SMatthias Springer return failure();
538f718a53dSMatthias Springer // Transfer ops that modify the element type are not supported atm.
539f718a53dSMatthias Springer if (xferOp.getVectorType().getElementType() !=
540f718a53dSMatthias Springer xferOp.getShapedType().getElementType())
541f718a53dSMatthias Springer return failure();
5420f241638SMatthias Springer return success();
5430f241638SMatthias Springer }
5440f241638SMatthias Springer
5450f241638SMatthias Springer /// Prepare a TransferReadOp for progressive lowering.
5460f241638SMatthias Springer ///
5470f241638SMatthias Springer /// 1. Allocate a temporary buffer.
5480f241638SMatthias Springer /// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
5490f241638SMatthias Springer /// 3. Store the result of the TransferReadOp into the temporary buffer.
5500f241638SMatthias Springer /// 4. Load the result from the temporary buffer and replace all uses of the
5510f241638SMatthias Springer /// original TransferReadOp with this load.
5520f241638SMatthias Springer ///
5530f241638SMatthias Springer /// E.g.:
5540f241638SMatthias Springer /// ```
5550f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %cst
5560f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?x?xf32>
5570f241638SMatthias Springer /// ```
5580f241638SMatthias Springer /// is rewritten to:
5590f241638SMatthias Springer /// ```
5600f241638SMatthias Springer /// %0 = memref.alloca() : memref<vector<5x4xf32>>
5610f241638SMatthias Springer /// %1 = vector.transfer_read %A[%a, %b, %c], %cst
5620f241638SMatthias Springer /// { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
5630f241638SMatthias Springer /// memref.store %1, %0[] : memref<vector<5x4xf32>>
5640f241638SMatthias Springer /// %vec = memref.load %0[] : memref<vector<5x4xf32>>
5650f241638SMatthias Springer /// ```
5660f241638SMatthias Springer ///
5670f241638SMatthias Springer /// Note: A second temporary buffer may be allocated for the `mask` operand.
5682ca887deSMatthias Springer struct PrepareTransferReadConversion
5692ca887deSMatthias Springer : public VectorToSCFPattern<TransferReadOp> {
5702ca887deSMatthias Springer using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
5710f241638SMatthias Springer
matchAndRewrite__anon4d9edda10111::lowering_n_d::PrepareTransferReadConversion5720f241638SMatthias Springer LogicalResult matchAndRewrite(TransferReadOp xferOp,
5730f241638SMatthias Springer PatternRewriter &rewriter) const override {
574fb7ec1f1SMatthias Springer if (checkPrepareXferOp(xferOp, options).failed())
5750f241638SMatthias Springer return failure();
5760f241638SMatthias Springer
5776825bfe2SNicolas Vasilache auto buffers = allocBuffers(rewriter, xferOp);
5780f241638SMatthias Springer auto *newXfer = rewriter.clone(*xferOp.getOperation());
5790f241638SMatthias Springer newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
5807c38fd60SJacques Pienaar if (xferOp.getMask()) {
5817c38fd60SJacques Pienaar dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
5820f241638SMatthias Springer buffers.maskBuffer);
5830f241638SMatthias Springer }
5840f241638SMatthias Springer
5856825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
5866825bfe2SNicolas Vasilache rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
5876825bfe2SNicolas Vasilache buffers.dataBuffer);
5880f241638SMatthias Springer rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);
5894ead2cf7SAlex Zinenko
5904ead2cf7SAlex Zinenko return success();
5914ead2cf7SAlex Zinenko }
5920f241638SMatthias Springer };
5930f241638SMatthias Springer
5940f241638SMatthias Springer /// Prepare a TransferWriteOp for progressive lowering.
5950f241638SMatthias Springer ///
5960f241638SMatthias Springer /// 1. Allocate a temporary buffer.
5970f241638SMatthias Springer /// 2. Store the vector into the buffer.
5980f241638SMatthias Springer /// 3. Load the vector from the buffer again.
5990f241638SMatthias Springer /// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
6000f241638SMatthias Springer /// marking it eligible for progressive lowering via TransferOpConversion.
6010f241638SMatthias Springer ///
6020f241638SMatthias Springer /// E.g.:
6030f241638SMatthias Springer /// ```
6040f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c]
6050f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?x?xf32>
6060f241638SMatthias Springer /// ```
6070f241638SMatthias Springer /// is rewritten to:
6080f241638SMatthias Springer /// ```
6090f241638SMatthias Springer /// %0 = memref.alloca() : memref<vector<5x4xf32>>
6100f241638SMatthias Springer /// memref.store %vec, %0[] : memref<vector<5x4xf32>>
6110f241638SMatthias Springer /// %1 = memref.load %0[] : memref<vector<5x4xf32>>
6120f241638SMatthias Springer /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
6130f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?x?xf32>
6140f241638SMatthias Springer /// ```
6150f241638SMatthias Springer ///
6160f241638SMatthias Springer /// Note: A second temporary buffer may be allocated for the `mask` operand.
6170f241638SMatthias Springer struct PrepareTransferWriteConversion
6182ca887deSMatthias Springer : public VectorToSCFPattern<TransferWriteOp> {
6192ca887deSMatthias Springer using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
6200f241638SMatthias Springer
matchAndRewrite__anon4d9edda10111::lowering_n_d::PrepareTransferWriteConversion6210f241638SMatthias Springer LogicalResult matchAndRewrite(TransferWriteOp xferOp,
6220f241638SMatthias Springer PatternRewriter &rewriter) const override {
623fb7ec1f1SMatthias Springer if (checkPrepareXferOp(xferOp, options).failed())
6240f241638SMatthias Springer return failure();
6250f241638SMatthias Springer
6266825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
6276825bfe2SNicolas Vasilache auto buffers = allocBuffers(rewriter, xferOp);
6287c38fd60SJacques Pienaar rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
6297c38fd60SJacques Pienaar buffers.dataBuffer);
6306825bfe2SNicolas Vasilache auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
6310f241638SMatthias Springer rewriter.updateRootInPlace(xferOp, [&]() {
6327c38fd60SJacques Pienaar xferOp.getVectorMutable().assign(loadedVec);
6330f241638SMatthias Springer xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
6340f241638SMatthias Springer });
6350f241638SMatthias Springer
6367c38fd60SJacques Pienaar if (xferOp.getMask()) {
6377c38fd60SJacques Pienaar rewriter.updateRootInPlace(xferOp, [&]() {
6387c38fd60SJacques Pienaar xferOp.getMaskMutable().assign(buffers.maskBuffer);
6397c38fd60SJacques Pienaar });
6400f241638SMatthias Springer }
6410f241638SMatthias Springer
6420f241638SMatthias Springer return success();
6430f241638SMatthias Springer }
6440f241638SMatthias Springer };
6450f241638SMatthias Springer
6460f241638SMatthias Springer /// Progressive lowering of vector transfer ops: Unpack one dimension.
6470f241638SMatthias Springer ///
6480f241638SMatthias Springer /// 1. Unpack one dimension from the current buffer type and cast the buffer
6490f241638SMatthias Springer /// to that new type. E.g.:
6500f241638SMatthias Springer /// ```
6510f241638SMatthias Springer /// %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
6520f241638SMatthias Springer /// vector.transfer_write %vec ...
6530f241638SMatthias Springer /// ```
6540f241638SMatthias Springer /// The following cast is generated:
6550f241638SMatthias Springer /// ```
6560f241638SMatthias Springer /// %casted = vector.type_cast %0
6570f241638SMatthias Springer /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
6580f241638SMatthias Springer /// ```
6590f241638SMatthias Springer /// 2. Generate a for loop and rewrite the transfer op according to the
6600f241638SMatthias Springer /// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
6610f241638SMatthias Springer /// out-of-bounds, generate an if-check and handle both cases separately.
6620f241638SMatthias Springer /// 3. Clean up according to the corresponding Strategy<OpTy>.
663558e7401SMatthias Springer ///
664558e7401SMatthias Springer /// Note: If the transfer op is a TransferWriteOp and operates on a tensor
665558e7401SMatthias Springer /// source (as opposed to a memref source), then each iteration of the generated
666558e7401SMatthias Springer /// scf.for loop yields the new tensor value. E.g.:
667558e7401SMatthias Springer /// ```
668558e7401SMatthias Springer /// %result = scf.for i = 0 to 5 {
669558e7401SMatthias Springer /// %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
670558e7401SMatthias Springer /// %1 = vector.transfer_write %0, %source[...]
671558e7401SMatthias Springer /// : vector<4x3xf32>, tensor<5x4x3xf32>
672558e7401SMatthias Springer /// scf.yield %1 : tensor<5x4x3xf32>
673558e7401SMatthias Springer /// }
674558e7401SMatthias Springer /// ```
6750f241638SMatthias Springer template <typename OpTy>
6762ca887deSMatthias Springer struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
6772ca887deSMatthias Springer using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
6780f241638SMatthias Springer
initialize__anon4d9edda10111::lowering_n_d::TransferOpConversion679700b64dcSMatthias Springer void initialize() {
680700b64dcSMatthias Springer // This pattern recursively unpacks one dimension at a time. The recursion
681700b64dcSMatthias Springer // bounded as the rank is strictly decreasing.
682700b64dcSMatthias Springer this->setHasBoundedRewriteRecursion();
683700b64dcSMatthias Springer }
684700b64dcSMatthias Springer
matchAndRewrite__anon4d9edda10111::lowering_n_d::TransferOpConversion6850f241638SMatthias Springer LogicalResult matchAndRewrite(OpTy xferOp,
6860f241638SMatthias Springer PatternRewriter &rewriter) const override {
6870f241638SMatthias Springer if (!xferOp->hasAttr(kPassLabel))
6880f241638SMatthias Springer return failure();
6890f241638SMatthias Springer
6900f241638SMatthias Springer // Find and cast data buffer. How the buffer can be found depends on OpTy.
6916825bfe2SNicolas Vasilache ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
6920f241638SMatthias Springer auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
6930f241638SMatthias Springer auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
6940f241638SMatthias Springer auto castedDataType = unpackOneDim(dataBufferType);
6956825bfe2SNicolas Vasilache auto castedDataBuffer =
6966825bfe2SNicolas Vasilache locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);
6970f241638SMatthias Springer
6980f241638SMatthias Springer // If the xferOp has a mask: Find and cast mask buffer.
6990f241638SMatthias Springer Value castedMaskBuffer;
7007c38fd60SJacques Pienaar if (xferOp.getMask()) {
7010f241638SMatthias Springer auto maskBuffer = getMaskBuffer(xferOp);
7020f241638SMatthias Springer auto maskBufferType =
7030f241638SMatthias Springer maskBuffer.getType().template dyn_cast<MemRefType>();
7040f241638SMatthias Springer if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
7050f241638SMatthias Springer // Do not unpack a dimension of the mask, if:
7060f241638SMatthias Springer // * To-be-unpacked transfer op dimension is a broadcast.
7070f241638SMatthias Springer // * Mask is 1D, i.e., the mask cannot be further unpacked.
7080f241638SMatthias Springer // (That means that all remaining dimensions of the transfer op must
7090f241638SMatthias Springer // be broadcasted.)
7100f241638SMatthias Springer castedMaskBuffer = maskBuffer;
7110f241638SMatthias Springer } else {
7120f241638SMatthias Springer auto castedMaskType = unpackOneDim(maskBufferType);
7136825bfe2SNicolas Vasilache castedMaskBuffer =
7146825bfe2SNicolas Vasilache locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
7150f241638SMatthias Springer }
7160f241638SMatthias Springer }
7170f241638SMatthias Springer
7180f241638SMatthias Springer // Loop bounds and step.
719a54f4eaeSMogball auto lb = locB.create<arith::ConstantIndexOp>(0);
720a54f4eaeSMogball auto ub = locB.create<arith::ConstantIndexOp>(
7216825bfe2SNicolas Vasilache castedDataType.getDimSize(castedDataType.getRank() - 1));
722a54f4eaeSMogball auto step = locB.create<arith::ConstantIndexOp>(1);
723558e7401SMatthias Springer // TransferWriteOps that operate on tensors return the modified tensor and
724558e7401SMatthias Springer // require a loop state.
725558e7401SMatthias Springer auto loopState = Strategy<OpTy>::initialLoopState(xferOp);
7260f241638SMatthias Springer
7270f241638SMatthias Springer // Generate for loop.
728558e7401SMatthias Springer auto result = locB.create<scf::ForOp>(
729558e7401SMatthias Springer lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
730558e7401SMatthias Springer [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
731558e7401SMatthias Springer Type stateType = loopState.empty() ? Type() : loopState[0].getType();
732558e7401SMatthias Springer
733558e7401SMatthias Springer auto result = generateInBoundsCheck(
7346825bfe2SNicolas Vasilache b, xferOp, iv, unpackedDim(xferOp),
735558e7401SMatthias Springer stateType ? TypeRange(stateType) : TypeRange(),
7360f241638SMatthias Springer /*inBoundsCase=*/
7376825bfe2SNicolas Vasilache [&](OpBuilder &b, Location loc) {
7380f241638SMatthias Springer // Create new transfer op.
7392ca887deSMatthias Springer OpTy newXfer = Strategy<OpTy>::rewriteOp(
740558e7401SMatthias Springer b, this->options, xferOp, castedDataBuffer, iv, loopState);
7410f241638SMatthias Springer
7420f241638SMatthias Springer // If old transfer op has a mask: Set mask on new transfer op.
7430f241638SMatthias Springer // Special case: If the mask of the old transfer op is 1D and
7440f241638SMatthias Springer // the
7450f241638SMatthias Springer // unpacked dim is not a broadcast, no mask is
7460f241638SMatthias Springer // needed on the new transfer op.
7477c38fd60SJacques Pienaar if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
7480f241638SMatthias Springer xferOp.getMaskType().getRank() > 1)) {
7490f241638SMatthias Springer OpBuilder::InsertionGuard guard(b);
7500f241638SMatthias Springer b.setInsertionPoint(newXfer); // Insert load before newXfer.
7510f241638SMatthias Springer
7520f241638SMatthias Springer SmallVector<Value, 8> loadIndices;
7530f241638SMatthias Springer Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
7540f241638SMatthias Springer // In case of broadcast: Use same indices to load from memref
7550f241638SMatthias Springer // as before.
7560f241638SMatthias Springer if (!xferOp.isBroadcastDim(0))
7570f241638SMatthias Springer loadIndices.push_back(iv);
7580f241638SMatthias Springer
7596825bfe2SNicolas Vasilache auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
7606825bfe2SNicolas Vasilache loadIndices);
7617c38fd60SJacques Pienaar rewriter.updateRootInPlace(newXfer, [&]() {
7627c38fd60SJacques Pienaar newXfer.getMaskMutable().assign(mask);
7637c38fd60SJacques Pienaar });
7640f241638SMatthias Springer }
765558e7401SMatthias Springer
766558e7401SMatthias Springer return loopState.empty() ? Value() : newXfer->getResult(0);
7670f241638SMatthias Springer },
7680f241638SMatthias Springer /*outOfBoundsCase=*/
7690f241638SMatthias Springer [&](OpBuilder &b, Location /*loc*/) {
770558e7401SMatthias Springer return Strategy<OpTy>::handleOutOfBoundsDim(
771558e7401SMatthias Springer b, xferOp, castedDataBuffer, iv, loopState);
7720f241638SMatthias Springer });
7730f241638SMatthias Springer
774558e7401SMatthias Springer maybeYieldValue(b, loc, !loopState.empty(), result);
775558e7401SMatthias Springer });
776558e7401SMatthias Springer
777558e7401SMatthias Springer Strategy<OpTy>::cleanup(rewriter, xferOp, result);
7780f241638SMatthias Springer return success();
7790f241638SMatthias Springer }
7800f241638SMatthias Springer };
7810f241638SMatthias Springer
782a088bed4SMatthias Springer } // namespace lowering_n_d
783a088bed4SMatthias Springer
784a088bed4SMatthias Springer namespace lowering_n_d_unrolled {
785a088bed4SMatthias Springer
7860f241638SMatthias Springer /// If the original transfer op has a mask, compute the mask of the new transfer
7870f241638SMatthias Springer /// op (for the current iteration `i`) and assign it.
7880f241638SMatthias Springer template <typename OpTy>
maybeAssignMask(OpBuilder & b,OpTy xferOp,OpTy newXferOp,int64_t i)7896825bfe2SNicolas Vasilache static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
7900f241638SMatthias Springer int64_t i) {
7917c38fd60SJacques Pienaar if (!xferOp.getMask())
7920f241638SMatthias Springer return;
7930f241638SMatthias Springer
7940f241638SMatthias Springer if (xferOp.isBroadcastDim(0)) {
7950f241638SMatthias Springer // To-be-unpacked dimension is a broadcast, which does not have a
7960f241638SMatthias Springer // corresponding mask dimension. Mask attribute remains unchanged.
7977c38fd60SJacques Pienaar newXferOp.getMaskMutable().assign(xferOp.getMask());
7980f241638SMatthias Springer return;
7990f241638SMatthias Springer }
8000f241638SMatthias Springer
8010f241638SMatthias Springer if (xferOp.getMaskType().getRank() > 1) {
8020f241638SMatthias Springer // Unpack one dimension of the mask.
8036825bfe2SNicolas Vasilache OpBuilder::InsertionGuard guard(b);
8046825bfe2SNicolas Vasilache b.setInsertionPoint(newXferOp); // Insert load before newXfer.
8050f241638SMatthias Springer
8060f241638SMatthias Springer llvm::SmallVector<int64_t, 1> indices({i});
8076825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
8087c38fd60SJacques Pienaar auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
8097c38fd60SJacques Pienaar newXferOp.getMaskMutable().assign(newMask);
8100f241638SMatthias Springer }
8110f241638SMatthias Springer
8120f241638SMatthias Springer // If we end up here: The mask of the old transfer op is 1D and the unpacked
8130f241638SMatthias Springer // dim is not a broadcast, so no mask is needed on the new transfer op.
8140f241638SMatthias Springer // `generateInBoundsCheck` will have evaluated the mask already.
8150f241638SMatthias Springer }
8160f241638SMatthias Springer
8170f241638SMatthias Springer /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
8180f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
8190f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
8200f241638SMatthias Springer ///
8210f241638SMatthias Springer /// ```
8220f241638SMatthias Springer /// E.g.:
8230f241638SMatthias Springer /// ```
8240f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %padding
8250f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<5x4xf32>
8260f241638SMatthias Springer /// ```
8270f241638SMatthias Springer /// is rewritten to IR such as (simplified):
8280f241638SMatthias Springer /// ```
8290f241638SMatthias Springer /// %v_init = splat %padding : vector<5x4xf32>
8300f241638SMatthias Springer /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
8310f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<4xf32>
8320f241638SMatthias Springer /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
8330f241638SMatthias Springer /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
8340f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<4xf32>
8350f241638SMatthias Springer /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
8360f241638SMatthias Springer /// ...
8370f241638SMatthias Springer /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
8380f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<4xf32>
8390f241638SMatthias Springer /// %vec = vector.insert %tmp1, %v3[4] : vector<4xf32> into vector<5x4xf32>
8400f241638SMatthias Springer /// ```
8410f241638SMatthias Springer ///
8420f241638SMatthias Springer /// Note: As an optimization, if the result of the original TransferReadOp
8430f241638SMatthias Springer /// was directly inserted into another vector, no new %v_init vector is created.
8440f241638SMatthias Springer /// Instead, the new TransferReadOp results are inserted into that vector.
8452ca887deSMatthias Springer struct UnrollTransferReadConversion
8462ca887deSMatthias Springer : public VectorToSCFPattern<TransferReadOp> {
8472ca887deSMatthias Springer using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;
8480f241638SMatthias Springer
initialize__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferReadConversion849700b64dcSMatthias Springer void initialize() {
850700b64dcSMatthias Springer // This pattern recursively unpacks one dimension at a time. The recursion
851700b64dcSMatthias Springer // bounded as the rank is strictly decreasing.
852700b64dcSMatthias Springer setHasBoundedRewriteRecursion();
853700b64dcSMatthias Springer }
854700b64dcSMatthias Springer
8550f241638SMatthias Springer /// Return the vector into which the newly created TransferReadOp results
8560f241638SMatthias Springer /// are inserted.
getResultVector__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferReadConversion8570f241638SMatthias Springer Value getResultVector(TransferReadOp xferOp,
8580f241638SMatthias Springer PatternRewriter &rewriter) const {
8590f241638SMatthias Springer if (auto insertOp = getInsertOp(xferOp))
8607c38fd60SJacques Pienaar return insertOp.getDest();
8616825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
8626a8ba318SRiver Riddle return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
8637c38fd60SJacques Pienaar xferOp.getPadding());
8640f241638SMatthias Springer }
8650f241638SMatthias Springer
8660f241638SMatthias Springer /// If the result of the TransferReadOp has exactly one user, which is a
8670f241638SMatthias Springer /// vector::InsertOp, return that operation.
getInsertOp__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferReadConversion8680f241638SMatthias Springer vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
8690f241638SMatthias Springer if (xferOp->hasOneUse()) {
8700f241638SMatthias Springer Operation *xferOpUser = *xferOp->getUsers().begin();
8710f241638SMatthias Springer if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
8720f241638SMatthias Springer return insertOp;
8730f241638SMatthias Springer }
8740f241638SMatthias Springer
8750f241638SMatthias Springer return vector::InsertOp();
8760f241638SMatthias Springer }
8770f241638SMatthias Springer
8780f241638SMatthias Springer /// If the result of the TransferReadOp has exactly one user, which is a
8790f241638SMatthias Springer /// vector::InsertOp, return that operation's indices.
getInsertionIndices__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferReadConversion8800f241638SMatthias Springer void getInsertionIndices(TransferReadOp xferOp,
8810f241638SMatthias Springer SmallVector<int64_t, 8> &indices) const {
8820f241638SMatthias Springer if (auto insertOp = getInsertOp(xferOp)) {
883*c730f9a1SKazu Hirata for (Attribute attr : insertOp.getPosition())
8840f241638SMatthias Springer indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
8850f241638SMatthias Springer }
8860f241638SMatthias Springer }
8870f241638SMatthias Springer
8880f241638SMatthias Springer /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
8890f241638SMatthias Springer /// accesses, and broadcasts and transposes in permutation maps.
matchAndRewrite__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferReadConversion8900f241638SMatthias Springer LogicalResult matchAndRewrite(TransferReadOp xferOp,
8910f241638SMatthias Springer PatternRewriter &rewriter) const override {
8922ca887deSMatthias Springer if (xferOp.getVectorType().getRank() <= options.targetRank)
8930f241638SMatthias Springer return failure();
894bd20756dSMatthias Springer if (isTensorOp(xferOp) && !options.lowerTensors)
8958fb48979SMatthias Springer return failure();
896f718a53dSMatthias Springer // Transfer ops that modify the element type are not supported atm.
897f718a53dSMatthias Springer if (xferOp.getVectorType().getElementType() !=
898f718a53dSMatthias Springer xferOp.getShapedType().getElementType())
899f718a53dSMatthias Springer return failure();
9000f241638SMatthias Springer
9010f241638SMatthias Springer auto insertOp = getInsertOp(xferOp);
9020f241638SMatthias Springer auto vec = getResultVector(xferOp, rewriter);
9030f241638SMatthias Springer auto vecType = vec.getType().dyn_cast<VectorType>();
9040f241638SMatthias Springer auto xferVecType = xferOp.getVectorType();
9050f241638SMatthias Springer auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
9060f241638SMatthias Springer xferVecType.getElementType());
9070f241638SMatthias Springer int64_t dimSize = xferVecType.getShape()[0];
9080f241638SMatthias Springer
9090f241638SMatthias Springer // Generate fully unrolled loop of transfer ops.
9106825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
9110f241638SMatthias Springer for (int64_t i = 0; i < dimSize; ++i) {
912a54f4eaeSMogball Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
9130f241638SMatthias Springer
9140f241638SMatthias Springer vec = generateInBoundsCheck(
9156825bfe2SNicolas Vasilache rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
9160f241638SMatthias Springer /*inBoundsCase=*/
9170f241638SMatthias Springer [&](OpBuilder &b, Location loc) {
9180f241638SMatthias Springer // Indices for the new transfer op.
9190f241638SMatthias Springer SmallVector<Value, 8> xferIndices;
9206825bfe2SNicolas Vasilache getXferIndices(b, xferOp, iv, xferIndices);
9210f241638SMatthias Springer
9220f241638SMatthias Springer // Indices for the new vector.insert op.
9230f241638SMatthias Springer SmallVector<int64_t, 8> insertionIndices;
9240f241638SMatthias Springer getInsertionIndices(xferOp, insertionIndices);
9250f241638SMatthias Springer insertionIndices.push_back(i);
9260f241638SMatthias Springer
9277c38fd60SJacques Pienaar auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
9286825bfe2SNicolas Vasilache auto newXferOp = b.create<vector::TransferReadOp>(
9297c38fd60SJacques Pienaar loc, newXferVecType, xferOp.getSource(), xferIndices,
9306825bfe2SNicolas Vasilache AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
9317c38fd60SJacques Pienaar xferOp.getPadding(), Value(), inBoundsAttr);
9320f241638SMatthias Springer maybeAssignMask(b, xferOp, newXferOp, i);
9336825bfe2SNicolas Vasilache return b.create<vector::InsertOp>(loc, newXferOp, vec,
9346825bfe2SNicolas Vasilache insertionIndices);
9350f241638SMatthias Springer },
9360f241638SMatthias Springer /*outOfBoundsCase=*/
9370f241638SMatthias Springer [&](OpBuilder &b, Location loc) {
9380f241638SMatthias Springer // Loop through original (unmodified) vector.
9390f241638SMatthias Springer return vec;
9400f241638SMatthias Springer });
9410f241638SMatthias Springer }
9420f241638SMatthias Springer
9430f241638SMatthias Springer if (insertOp) {
9440f241638SMatthias Springer // Rewrite single user of the old TransferReadOp, which was an InsertOp.
9450f241638SMatthias Springer rewriter.replaceOp(insertOp, vec);
9460f241638SMatthias Springer rewriter.eraseOp(xferOp);
9470f241638SMatthias Springer } else {
9480f241638SMatthias Springer rewriter.replaceOp(xferOp, vec);
9490f241638SMatthias Springer }
9500f241638SMatthias Springer
9510f241638SMatthias Springer return success();
9520f241638SMatthias Springer }
9530f241638SMatthias Springer };
9540f241638SMatthias Springer
9550f241638SMatthias Springer /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one
9560f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no
9570f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled.
9580f241638SMatthias Springer ///
9590f241638SMatthias Springer /// ```
9600f241638SMatthias Springer /// E.g.:
9610f241638SMatthias Springer /// ```
9620f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c]
9630f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?x?xf32>
9640f241638SMatthias Springer /// ```
9650f241638SMatthias Springer /// is rewritten to IR such as (simplified):
9660f241638SMatthias Springer /// ```
9670f241638SMatthias Springer /// %v0 = vector.extract %vec[0] : vector<5x4xf32>
9680f241638SMatthias Springer /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...>
9690f241638SMatthias Springer /// %v1 = vector.extract %vec[1] : vector<5x4xf32>
9700f241638SMatthias Springer /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...>
9710f241638SMatthias Springer /// ...
9720f241638SMatthias Springer /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
9730f241638SMatthias Springer /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
9740f241638SMatthias Springer /// ```
9750f241638SMatthias Springer ///
9760f241638SMatthias Springer /// Note: As an optimization, if the vector of the original TransferWriteOp
9770f241638SMatthias Springer /// was directly extracted from another vector via an ExtractOp `a`, extract
9780f241638SMatthias Springer /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
9790f241638SMatthias Springer /// doing so, `a` may become dead, and the number of ExtractOps generated during
9800f241638SMatthias Springer /// recursive application of this pattern will be minimal.
9810f241638SMatthias Springer struct UnrollTransferWriteConversion
9822ca887deSMatthias Springer : public VectorToSCFPattern<TransferWriteOp> {
9832ca887deSMatthias Springer using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
9840f241638SMatthias Springer
initialize__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferWriteConversion985700b64dcSMatthias Springer void initialize() {
986700b64dcSMatthias Springer // This pattern recursively unpacks one dimension at a time. The recursion
987700b64dcSMatthias Springer // bounded as the rank is strictly decreasing.
988700b64dcSMatthias Springer setHasBoundedRewriteRecursion();
989700b64dcSMatthias Springer }
990700b64dcSMatthias Springer
9910f241638SMatthias Springer /// Return the vector from which newly generated ExtracOps will extract.
getDataVector__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferWriteConversion9920f241638SMatthias Springer Value getDataVector(TransferWriteOp xferOp) const {
9930f241638SMatthias Springer if (auto extractOp = getExtractOp(xferOp))
9947c38fd60SJacques Pienaar return extractOp.getVector();
9957c38fd60SJacques Pienaar return xferOp.getVector();
9960f241638SMatthias Springer }
9970f241638SMatthias Springer
9980f241638SMatthias Springer /// If the input of the given TransferWriteOp is an ExtractOp, return it.
getExtractOp__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferWriteConversion9990f241638SMatthias Springer vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
10007c38fd60SJacques Pienaar if (auto *op = xferOp.getVector().getDefiningOp())
10010f241638SMatthias Springer return dyn_cast<vector::ExtractOp>(op);
10020f241638SMatthias Springer return vector::ExtractOp();
10030f241638SMatthias Springer }
10040f241638SMatthias Springer
10050f241638SMatthias Springer /// If the input of the given TransferWriteOp is an ExtractOp, return its
10060f241638SMatthias Springer /// indices.
getExtractionIndices__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferWriteConversion10070f241638SMatthias Springer void getExtractionIndices(TransferWriteOp xferOp,
10080f241638SMatthias Springer SmallVector<int64_t, 8> &indices) const {
10090f241638SMatthias Springer if (auto extractOp = getExtractOp(xferOp)) {
1010*c730f9a1SKazu Hirata for (Attribute attr : extractOp.getPosition())
10110f241638SMatthias Springer indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
10120f241638SMatthias Springer }
10130f241638SMatthias Springer }
10140f241638SMatthias Springer
10150f241638SMatthias Springer /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
10160f241638SMatthias Springer /// accesses, and broadcasts and transposes in permutation maps.
matchAndRewrite__anon4d9edda10111::lowering_n_d_unrolled::UnrollTransferWriteConversion10170f241638SMatthias Springer LogicalResult matchAndRewrite(TransferWriteOp xferOp,
10180f241638SMatthias Springer PatternRewriter &rewriter) const override {
10192ca887deSMatthias Springer if (xferOp.getVectorType().getRank() <= options.targetRank)
10200f241638SMatthias Springer return failure();
1021bd20756dSMatthias Springer if (isTensorOp(xferOp) && !options.lowerTensors)
10228fb48979SMatthias Springer return failure();
1023f718a53dSMatthias Springer // Transfer ops that modify the element type are not supported atm.
1024f718a53dSMatthias Springer if (xferOp.getVectorType().getElementType() !=
1025f718a53dSMatthias Springer xferOp.getShapedType().getElementType())
1026f718a53dSMatthias Springer return failure();
10270f241638SMatthias Springer
10280f241638SMatthias Springer auto vec = getDataVector(xferOp);
10290f241638SMatthias Springer auto xferVecType = xferOp.getVectorType();
10300f241638SMatthias Springer int64_t dimSize = xferVecType.getShape()[0];
10317c38fd60SJacques Pienaar auto source = xferOp.getSource(); // memref or tensor to be written to.
1032bd20756dSMatthias Springer auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
10330f241638SMatthias Springer
10340f241638SMatthias Springer // Generate fully unrolled loop of transfer ops.
10356825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
10360f241638SMatthias Springer for (int64_t i = 0; i < dimSize; ++i) {
1037a54f4eaeSMogball Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);
10380f241638SMatthias Springer
1039bd20756dSMatthias Springer auto updatedSource = generateInBoundsCheck(
10406825bfe2SNicolas Vasilache rewriter, xferOp, iv, unpackedDim(xferOp),
1041bd20756dSMatthias Springer isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(),
1042bd20756dSMatthias Springer /*inBoundsCase=*/
1043bd20756dSMatthias Springer [&](OpBuilder &b, Location loc) {
10440f241638SMatthias Springer // Indices for the new transfer op.
10450f241638SMatthias Springer SmallVector<Value, 8> xferIndices;
10466825bfe2SNicolas Vasilache getXferIndices(b, xferOp, iv, xferIndices);
10470f241638SMatthias Springer
10480f241638SMatthias Springer // Indices for the new vector.extract op.
10490f241638SMatthias Springer SmallVector<int64_t, 8> extractionIndices;
10500f241638SMatthias Springer getExtractionIndices(xferOp, extractionIndices);
10510f241638SMatthias Springer extractionIndices.push_back(i);
10520f241638SMatthias Springer
10536825bfe2SNicolas Vasilache auto extracted =
10546825bfe2SNicolas Vasilache b.create<vector::ExtractOp>(loc, vec, extractionIndices);
10557c38fd60SJacques Pienaar auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
10566825bfe2SNicolas Vasilache auto newXferOp = b.create<vector::TransferWriteOp>(
1057bd20756dSMatthias Springer loc, sourceType, extracted, source, xferIndices,
10586825bfe2SNicolas Vasilache AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
10596825bfe2SNicolas Vasilache inBoundsAttr);
10600f241638SMatthias Springer
10610f241638SMatthias Springer maybeAssignMask(b, xferOp, newXferOp, i);
1062bd20756dSMatthias Springer
1063bd20756dSMatthias Springer return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value();
1064bd20756dSMatthias Springer },
1065bd20756dSMatthias Springer /*outOfBoundsCase=*/
1066bd20756dSMatthias Springer [&](OpBuilder &b, Location loc) {
1067bd20756dSMatthias Springer return isTensorOp(xferOp) ? source : Value();
10680f241638SMatthias Springer });
1069bd20756dSMatthias Springer
1070bd20756dSMatthias Springer if (isTensorOp(xferOp))
1071bd20756dSMatthias Springer source = updatedSource;
10720f241638SMatthias Springer }
10730f241638SMatthias Springer
1074bd20756dSMatthias Springer if (isTensorOp(xferOp))
1075bd20756dSMatthias Springer rewriter.replaceOp(xferOp, source);
1076bd20756dSMatthias Springer else
10770f241638SMatthias Springer rewriter.eraseOp(xferOp);
1078bd20756dSMatthias Springer
10790f241638SMatthias Springer return success();
10800f241638SMatthias Springer }
10810f241638SMatthias Springer };
10820f241638SMatthias Springer
1083a088bed4SMatthias Springer } // namespace lowering_n_d_unrolled
1084a088bed4SMatthias Springer
1085a088bed4SMatthias Springer namespace lowering_1_d {
1086a088bed4SMatthias Springer
10870f241638SMatthias Springer /// Compute the indices into the memref for the LoadOp/StoreOp generated as
10880f241638SMatthias Springer /// part of TransferOp1dConversion. Return the memref dimension on which
10890f241638SMatthias Springer /// the transfer is operating. A return value of None indicates a broadcast.
10900f241638SMatthias Springer template <typename OpTy>
10910f241638SMatthias Springer static Optional<int64_t>
get1dMemrefIndices(OpBuilder & b,OpTy xferOp,Value iv,SmallVector<Value,8> & memrefIndices)10926825bfe2SNicolas Vasilache get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv,
10930f241638SMatthias Springer SmallVector<Value, 8> &memrefIndices) {
10947c38fd60SJacques Pienaar auto indices = xferOp.getIndices();
10957c38fd60SJacques Pienaar auto map = xferOp.getPermutationMap();
1096c537a943SNicolas Vasilache assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
10970f241638SMatthias Springer
10980f241638SMatthias Springer memrefIndices.append(indices.begin(), indices.end());
10990f241638SMatthias Springer assert(map.getNumResults() == 1 &&
11000f241638SMatthias Springer "Expected 1 permutation map result for 1D transfer");
11010f241638SMatthias Springer if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
11026825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
11030f241638SMatthias Springer auto dim = expr.getPosition();
11046825bfe2SNicolas Vasilache AffineExpr d0, d1;
11056825bfe2SNicolas Vasilache bindDims(xferOp.getContext(), d0, d1);
11066825bfe2SNicolas Vasilache Value offset = memrefIndices[dim];
11076825bfe2SNicolas Vasilache memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
11080f241638SMatthias Springer return dim;
11090f241638SMatthias Springer }
11100f241638SMatthias Springer
11110f241638SMatthias Springer assert(xferOp.isBroadcastDim(0) &&
11120f241638SMatthias Springer "Expected AffineDimExpr or AffineConstantExpr");
11130f241638SMatthias Springer return None;
11140f241638SMatthias Springer }
11150f241638SMatthias Springer
11160f241638SMatthias Springer /// Codegen strategy for TransferOp1dConversion, depending on the
11170f241638SMatthias Springer /// operation.
11180f241638SMatthias Springer template <typename OpTy>
11190f241638SMatthias Springer struct Strategy1d;
11200f241638SMatthias Springer
11210f241638SMatthias Springer /// Codegen strategy for TransferReadOp.
11220f241638SMatthias Springer template <>
11230f241638SMatthias Springer struct Strategy1d<TransferReadOp> {
generateForLoopBody__anon4d9edda10111::lowering_1_d::Strategy1d11246825bfe2SNicolas Vasilache static void generateForLoopBody(OpBuilder &b, Location loc,
11250f241638SMatthias Springer TransferReadOp xferOp, Value iv,
11260f241638SMatthias Springer ValueRange loopState) {
11270f241638SMatthias Springer SmallVector<Value, 8> indices;
11286825bfe2SNicolas Vasilache auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
11290f241638SMatthias Springer auto vec = loopState[0];
11300f241638SMatthias Springer
11310f241638SMatthias Springer // In case of out-of-bounds access, leave `vec` as is (was initialized with
11320f241638SMatthias Springer // padding value).
11330f241638SMatthias Springer auto nextVec = generateInBoundsCheck(
11346825bfe2SNicolas Vasilache b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
11350f241638SMatthias Springer /*inBoundsCase=*/
11366825bfe2SNicolas Vasilache [&](OpBuilder &b, Location loc) {
11377c38fd60SJacques Pienaar Value val =
11387c38fd60SJacques Pienaar b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
11397c5ecc8bSMogball return b.create<vector::InsertElementOp>(loc, val, vec, iv);
11400f241638SMatthias Springer },
11410f241638SMatthias Springer /*outOfBoundsCase=*/
11420f241638SMatthias Springer [&](OpBuilder & /*b*/, Location loc) { return vec; });
11436825bfe2SNicolas Vasilache b.create<scf::YieldOp>(loc, nextVec);
11440f241638SMatthias Springer }
11450f241638SMatthias Springer
initialLoopState__anon4d9edda10111::lowering_1_d::Strategy1d11466825bfe2SNicolas Vasilache static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
11470f241638SMatthias Springer // Inititalize vector with padding value.
11486825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
11496a8ba318SRiver Riddle return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
11507c38fd60SJacques Pienaar xferOp.getPadding());
11510f241638SMatthias Springer }
11520f241638SMatthias Springer };
11530f241638SMatthias Springer
11540f241638SMatthias Springer /// Codegen strategy for TransferWriteOp.
11550f241638SMatthias Springer template <>
11560f241638SMatthias Springer struct Strategy1d<TransferWriteOp> {
generateForLoopBody__anon4d9edda10111::lowering_1_d::Strategy1d11576825bfe2SNicolas Vasilache static void generateForLoopBody(OpBuilder &b, Location loc,
11580f241638SMatthias Springer TransferWriteOp xferOp, Value iv,
11590f241638SMatthias Springer ValueRange /*loopState*/) {
11600f241638SMatthias Springer SmallVector<Value, 8> indices;
11616825bfe2SNicolas Vasilache auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
11620f241638SMatthias Springer
11630f241638SMatthias Springer // Nothing to do in case of out-of-bounds access.
11640f241638SMatthias Springer generateInBoundsCheck(
11656825bfe2SNicolas Vasilache b, xferOp, iv, dim,
11666825bfe2SNicolas Vasilache /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
11676825bfe2SNicolas Vasilache auto val =
11687c38fd60SJacques Pienaar b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
11697c38fd60SJacques Pienaar b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
11700f241638SMatthias Springer });
11716825bfe2SNicolas Vasilache b.create<scf::YieldOp>(loc);
11720f241638SMatthias Springer }
11730f241638SMatthias Springer
initialLoopState__anon4d9edda10111::lowering_1_d::Strategy1d11746825bfe2SNicolas Vasilache static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
11756825bfe2SNicolas Vasilache return Value();
11766825bfe2SNicolas Vasilache }
11770f241638SMatthias Springer };
11780f241638SMatthias Springer
11790f241638SMatthias Springer /// Return true if the last dimension of the MemRefType has unit stride.
isLastMemrefDimUnitStride(MemRefType type)11800f241638SMatthias Springer static bool isLastMemrefDimUnitStride(MemRefType type) {
11810f241638SMatthias Springer int64_t offset;
11820f241638SMatthias Springer SmallVector<int64_t, 4> strides;
11830f241638SMatthias Springer auto successStrides = getStridesAndOffset(type, strides, offset);
11845017b0f8SMatthias Springer return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
11850f241638SMatthias Springer }
11860f241638SMatthias Springer
11870f241638SMatthias Springer /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
11880f241638SMatthias Springer /// necessary in cases where a 1D vector transfer op cannot be lowered into
11890f241638SMatthias Springer /// vector load/stores due to non-unit strides or broadcasts:
11900f241638SMatthias Springer ///
11910f241638SMatthias Springer /// * Transfer dimension is not the last memref dimension
11920f241638SMatthias Springer /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
11930f241638SMatthias Springer /// * Memref has a layout map with non-unit stride on the last dimension
11940f241638SMatthias Springer ///
11950f241638SMatthias Springer /// This pattern generates IR as follows:
11960f241638SMatthias Springer ///
11970f241638SMatthias Springer /// 1. Generate a for loop iterating over each vector element.
11980f241638SMatthias Springer /// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp,
11990f241638SMatthias Springer /// depending on OpTy.
12000f241638SMatthias Springer ///
12010f241638SMatthias Springer /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
12020f241638SMatthias Springer /// can be generated instead of TransferOp1dConversion. Add such a pattern
12030f241638SMatthias Springer /// to ConvertVectorToLLVM.
12040f241638SMatthias Springer ///
12050f241638SMatthias Springer /// E.g.:
12060f241638SMatthias Springer /// ```
12070f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b]
12080f241638SMatthias Springer /// {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]}
12090f241638SMatthias Springer /// : vector<9xf32>, memref<?x?xf32>
12100f241638SMatthias Springer /// ```
12110f241638SMatthias Springer /// Is rewritten to approximately the following pseudo-IR:
12120f241638SMatthias Springer /// ```
12130f241638SMatthias Springer /// for i = 0 to 9 {
12140f241638SMatthias Springer /// %t = vector.extractelement %vec[i] : vector<9xf32>
12150f241638SMatthias Springer /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
12160f241638SMatthias Springer /// }
12170f241638SMatthias Springer /// ```
12180f241638SMatthias Springer template <typename OpTy>
12192ca887deSMatthias Springer struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> {
12202ca887deSMatthias Springer using VectorToSCFPattern<OpTy>::VectorToSCFPattern;
12210f241638SMatthias Springer
matchAndRewrite__anon4d9edda10111::lowering_1_d::TransferOp1dConversion12220f241638SMatthias Springer LogicalResult matchAndRewrite(OpTy xferOp,
12230f241638SMatthias Springer PatternRewriter &rewriter) const override {
1224c537a943SNicolas Vasilache // TODO: support 0-d corner case.
1225c537a943SNicolas Vasilache if (xferOp.getTransferRank() == 0)
1226c537a943SNicolas Vasilache return failure();
12277c38fd60SJacques Pienaar auto map = xferOp.getPermutationMap();
12280f241638SMatthias Springer auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>();
12290f241638SMatthias Springer
12300f241638SMatthias Springer if (!memRefType)
12310f241638SMatthias Springer return failure();
12320f241638SMatthias Springer if (xferOp.getVectorType().getRank() != 1)
12330f241638SMatthias Springer return failure();
12340f241638SMatthias Springer if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType))
12350f241638SMatthias Springer return failure(); // Handled by ConvertVectorToLLVM
12360f241638SMatthias Springer
12370f241638SMatthias Springer // Loop bounds, step, state...
12386825bfe2SNicolas Vasilache Location loc = xferOp.getLoc();
12390f241638SMatthias Springer auto vecType = xferOp.getVectorType();
1240a54f4eaeSMogball auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0);
1241a54f4eaeSMogball auto ub =
1242a54f4eaeSMogball rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0));
1243a54f4eaeSMogball auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
12446825bfe2SNicolas Vasilache auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp);
12450f241638SMatthias Springer
12460f241638SMatthias Springer // Generate for loop.
12470f241638SMatthias Springer rewriter.replaceOpWithNewOp<scf::ForOp>(
12480f241638SMatthias Springer xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
12496825bfe2SNicolas Vasilache [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
12506825bfe2SNicolas Vasilache Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState);
12510f241638SMatthias Springer });
12520f241638SMatthias Springer
12530f241638SMatthias Springer return success();
12540f241638SMatthias Springer }
12550f241638SMatthias Springer };
12564ead2cf7SAlex Zinenko
1257a088bed4SMatthias Springer } // namespace lowering_1_d
1258df63eedeSBenjamin Kramer } // namespace
1259df63eedeSBenjamin Kramer
populateVectorToSCFConversionPatterns(RewritePatternSet & patterns,const VectorTransferToSCFOptions & options)126047f175b0SRiver Riddle void mlir::populateVectorToSCFConversionPatterns(
1261dc4e913bSChris Lattner RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) {
12620f241638SMatthias Springer if (options.unroll) {
1263a088bed4SMatthias Springer patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion,
1264a088bed4SMatthias Springer lowering_n_d_unrolled::UnrollTransferWriteConversion>(
12652ca887deSMatthias Springer patterns.getContext(), options);
12660f241638SMatthias Springer } else {
1267a088bed4SMatthias Springer patterns.add<lowering_n_d::PrepareTransferReadConversion,
1268a088bed4SMatthias Springer lowering_n_d::PrepareTransferWriteConversion,
1269a088bed4SMatthias Springer lowering_n_d::TransferOpConversion<TransferReadOp>,
1270a088bed4SMatthias Springer lowering_n_d::TransferOpConversion<TransferWriteOp>>(
1271a088bed4SMatthias Springer patterns.getContext(), options);
12720f241638SMatthias Springer }
12730f241638SMatthias Springer
12742ca887deSMatthias Springer if (options.targetRank == 1) {
1275a088bed4SMatthias Springer patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>,
1276a088bed4SMatthias Springer lowering_1_d::TransferOp1dConversion<TransferWriteOp>>(
1277a088bed4SMatthias Springer patterns.getContext(), options);
12780f241638SMatthias Springer }
12794ead2cf7SAlex Zinenko }
12803393cc4cSNicolas Vasilache
12815f9e0466SNicolas Vasilache namespace {
12825f9e0466SNicolas Vasilache
12835f9e0466SNicolas Vasilache struct ConvertVectorToSCFPass
12845f9e0466SNicolas Vasilache : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> {
12855f9e0466SNicolas Vasilache ConvertVectorToSCFPass() = default;
ConvertVectorToSCFPass__anon4d9edda11411::ConvertVectorToSCFPass12865f9e0466SNicolas Vasilache ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
12875f9e0466SNicolas Vasilache this->fullUnroll = options.unroll;
12882ca887deSMatthias Springer this->targetRank = options.targetRank;
1289fb7ec1f1SMatthias Springer this->lowerPermutationMaps = options.lowerPermutationMaps;
1290558e7401SMatthias Springer this->lowerTensors = options.lowerTensors;
12915f9e0466SNicolas Vasilache }
12925f9e0466SNicolas Vasilache
runOnOperation__anon4d9edda11411::ConvertVectorToSCFPass129341574554SRiver Riddle void runOnOperation() override {
12942ca887deSMatthias Springer VectorTransferToSCFOptions options;
1295fb7ec1f1SMatthias Springer options.unroll = fullUnroll;
1296fb7ec1f1SMatthias Springer options.targetRank = targetRank;
1297fb7ec1f1SMatthias Springer options.lowerPermutationMaps = lowerPermutationMaps;
1298558e7401SMatthias Springer options.lowerTensors = lowerTensors;
1299fb7ec1f1SMatthias Springer
1300fb7ec1f1SMatthias Springer // Lower permutation maps first.
1301fb7ec1f1SMatthias Springer if (lowerPermutationMaps) {
130247f175b0SRiver Riddle RewritePatternSet lowerTransferPatterns(&getContext());
1303fb7ec1f1SMatthias Springer mlir::vector::populateVectorTransferPermutationMapLoweringPatterns(
1304fb7ec1f1SMatthias Springer lowerTransferPatterns);
130541574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(),
1306fb7ec1f1SMatthias Springer std::move(lowerTransferPatterns));
1307fb7ec1f1SMatthias Springer }
13082ca887deSMatthias Springer
130947f175b0SRiver Riddle RewritePatternSet patterns(&getContext());
13102ca887deSMatthias Springer populateVectorToSCFConversionPatterns(patterns, options);
131141574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
13125f9e0466SNicolas Vasilache }
13135f9e0466SNicolas Vasilache };
13145f9e0466SNicolas Vasilache
13155f9e0466SNicolas Vasilache } // namespace
13165f9e0466SNicolas Vasilache
13175f9e0466SNicolas Vasilache std::unique_ptr<Pass>
createConvertVectorToSCFPass(const VectorTransferToSCFOptions & options)13185f9e0466SNicolas Vasilache mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) {
13195f9e0466SNicolas Vasilache return std::make_unique<ConvertVectorToSCFPass>(options);
13205f9e0466SNicolas Vasilache }
1321