//===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector transfer operations to SCF.
//
//===----------------------------------------------------------------------===//

#include <type_traits>

#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"

#include "../PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"

using namespace mlir;
using vector::TransferReadOp;
using vector::TransferWriteOp;

namespace {

/// Attribute name used for labeling transfer ops during progressive lowering.
static const char kPassLabel[] = "__vector_to_scf_lowering__";

/// Patterns that inherit from this struct have access to
/// VectorTransferToSCFOptions.
template <typename OpTy>
struct VectorToSCFPattern : public OpRewritePattern<OpTy> {
  explicit VectorToSCFPattern(MLIRContext *context,
                              VectorTransferToSCFOptions opt)
      : OpRewritePattern<OpTy>(context), options(opt) {}

  VectorTransferToSCFOptions options;
};

/// Given a vector transfer op, calculate which dimension of the `source`
/// memref should be unpacked in the next application of TransferOpConversion.
/// A return value of None indicates a broadcast.
template <typename OpTy>
static Optional<int64_t> unpackedDim(OpTy xferOp) {
  // TODO: support 0-d corner case.
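  // For instance (illustrative), a permutation map (d0, d1, d2) -> (d2, d1)
  // unpacks dim 2 of the source next, whereas a leading constant-0 result,
  // e.g. (d0, d1) -> (0, d1), denotes a broadcast and yields None.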
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) {
    return expr.getPosition();
  }
  assert(xferOp.isBroadcastDim(0) &&
         "Expected AffineDimExpr or AffineConstantExpr");
  return None;
}

/// Compute the permutation map for the new (N-1)-D vector transfer op. This
/// map is identical to the current permutation map, but the first result is
/// omitted.
template <typename OpTy>
static AffineMap unpackedPermutationMap(OpBuilder &b, OpTy xferOp) {
  // TODO: support 0-d corner case.
  assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer");
  auto map = xferOp.getPermutationMap();
  return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(),
                        b.getContext());
}

/// Calculate the indices for the new vector transfer op.
///
/// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ...
///       --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3xf32>
///                             ^^^^^^^
///       `iv` is the iteration variable of the (new) surrounding loop.
template <typename OpTy>
static void getXferIndices(OpBuilder &b, OpTy xferOp, Value iv,
                           SmallVector<Value, 8> &indices) {
  typename OpTy::Adaptor adaptor(xferOp);
  // Corresponding memref dim of the vector dim that is unpacked.
  auto dim = unpackedDim(xferOp);
  auto prevIndices = adaptor.getIndices();
  indices.append(prevIndices.begin(), prevIndices.end());

  Location loc = xferOp.getLoc();
  bool isBroadcast = !dim.hasValue();
  if (!isBroadcast) {
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value offset = adaptor.getIndices()[dim.getValue()];
    indices[dim.getValue()] =
        makeComposedAffineApply(b, loc, d0 + d1, {offset, iv});
  }
}

static void maybeYieldValue(OpBuilder &b, Location loc, bool hasRetVal,
                            Value value) {
  if (hasRetVal) {
    assert(value && "Expected non-empty value");
    b.create<scf::YieldOp>(loc, value);
  } else {
    b.create<scf::YieldOp>(loc);
  }
}

/// Generates a boolean Value that is true if the iv-th bit in xferOp's mask
/// is set to true. No such check is generated under the following
/// circumstances:
/// * xferOp does not have a mask.
/// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is
///   computed and attached to the new transfer op in the pattern.)
/// * The to-be-unpacked dim of xferOp is a broadcast.
template <typename OpTy>
static Value generateMaskCheck(OpBuilder &b, OpTy xferOp, Value iv) {
  if (!xferOp.getMask())
    return Value();
  if (xferOp.getMaskType().getRank() != 1)
    return Value();
  if (xferOp.isBroadcastDim(0))
    return Value();

  Location loc = xferOp.getLoc();
  return b.create<vector::ExtractElementOp>(loc, xferOp.getMask(), iv);
}

/// Helper function for TransferOpConversion and TransferOp1dConversion.
/// Generate an in-bounds check if the transfer op may go out-of-bounds on the
/// specified dimension `dim` with the loop iteration variable `iv`.
/// E.g., when unpacking dimension 0 from:
/// ```
/// %vec = vector.transfer_read %A[%a, %b] %cst
///     : vector<5x4xf32>, memref<?x?xf32>
/// ```
/// An if check similar to this will be generated inside the loop:
/// ```
/// %d = memref.dim %A, %c0 : memref<?x?xf32>
/// if (%a + iv < %d) {
///   (in-bounds case)
/// } else {
///   (out-of-bounds case)
/// }
/// ```
///
/// If the transfer is 1D and has a mask, this function generates a more
/// complex check that also accounts for potentially masked-out elements.
///
/// This function variant returns the value returned by `inBoundsCase` or
/// `outOfBoundsCase`. The MLIR type of the return value must be specified in
/// `resultTypes`.
template <typename OpTy>
static Value generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    TypeRange resultTypes,
    function_ref<Value(OpBuilder &, Location)> inBoundsCase,
    function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  bool hasRetVal = !resultTypes.empty();
  Value cond; // Condition to be built...

  // Condition check 1: Access in-bounds?
  bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts.
  Location loc = xferOp.getLoc();
  ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
  if (!xferOp.isDimInBounds(0) && !isBroadcast) {
    Value memrefDim =
        vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
    AffineExpr d0, d1;
    bindDims(xferOp.getContext(), d0, d1);
    Value base = xferOp.getIndices()[dim.getValue()];
    Value memrefIdx = makeComposedAffineApply(b, loc, d0 + d1, {base, iv});
    cond = lb.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, memrefDim,
                                    memrefIdx);
  }

  // Condition check 2: Masked in?
  if (auto maskCond = generateMaskCheck(b, xferOp, iv)) {
    if (cond)
      cond = lb.create<arith::AndIOp>(cond, maskCond);
    else
      cond = maskCond;
  }

  // If the condition is non-empty, generate an SCF::IfOp.
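  // E.g. (illustrative), when both checks are present, the generated IR is
  // roughly:
  //   %cond = arith.andi %in_bounds, %mask_bit : i1
  //   %0 = scf.if %cond -> (vector<3xf32>) { /*inBoundsCase*/ }
  //        else { /*outOfBoundsCase*/ }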
  if (cond) {
    auto check = lb.create<scf::IfOp>(
        resultTypes, cond,
        /*thenBuilder=*/
        [&](OpBuilder &b, Location loc) {
          maybeYieldValue(b, loc, hasRetVal, inBoundsCase(b, loc));
        },
        /*elseBuilder=*/
        [&](OpBuilder &b, Location loc) {
          if (outOfBoundsCase) {
            maybeYieldValue(b, loc, hasRetVal, outOfBoundsCase(b, loc));
          } else {
            b.create<scf::YieldOp>(loc);
          }
        });

    return hasRetVal ? check.getResult(0) : Value();
  }

  // Condition is empty, no need for an SCF::IfOp.
  return inBoundsCase(b, loc);
}

/// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have
/// a return value. Consequently, this function does not have a return value.
template <typename OpTy>
static void generateInBoundsCheck(
    OpBuilder &b, OpTy xferOp, Value iv, Optional<int64_t> dim,
    function_ref<void(OpBuilder &, Location)> inBoundsCase,
    function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) {
  generateInBoundsCheck(
      b, xferOp, iv, dim, /*resultTypes=*/TypeRange(),
      /*inBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        inBoundsCase(b, loc);
        return Value();
      },
      /*outOfBoundsCase=*/
      [&](OpBuilder &b, Location loc) {
        if (outOfBoundsCase)
          outOfBoundsCase(b, loc);
        return Value();
      });
}

/// Given an ArrayAttr, return a copy where the first element is dropped.
static ArrayAttr dropFirstElem(OpBuilder &b, ArrayAttr attr) {
  if (!attr)
    return attr;
  return ArrayAttr::get(b.getContext(), attr.getValue().drop_front());
}

/// Add the pass label to a vector transfer op if its rank exceeds the target
/// rank.
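///
/// E.g. (illustrative), with `targetRank` = 1, a newly created 2-D transfer
/// op is labeled as follows and will be picked up again by
/// TransferOpConversion:
/// ```
/// %v = vector.transfer_read ... {__vector_to_scf_lowering__}
///     : memref<?x?xf32>, vector<4x3xf32>
/// ```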
template <typename OpTy>
static void maybeApplyPassLabel(OpBuilder &b, OpTy newXferOp,
                                unsigned targetRank) {
  if (newXferOp.getVectorType().getRank() > targetRank)
    newXferOp->setAttr(kPassLabel, b.getUnitAttr());
}

/// Return true if this transfer op operates on a source tensor.
template <typename OpTy>
static bool isTensorOp(OpTy xferOp) {
  if (xferOp.getShapedType().template isa<RankedTensorType>()) {
    if (xferOp.getOperationName().equals(TransferWriteOp::getOperationName())) {
      // TransferWriteOps on tensors have a result.
      assert(xferOp->getNumResults() > 0);
    }
    return true;
  }
  return false;
}

namespace lowering_n_d {

/// Helper data structure for data and mask buffers.
struct BufferAllocs {
  Value dataBuffer;
  Value maskBuffer;
};

// TODO: Parallelism and threadlocal considerations with a ParallelScope trait.
static Operation *getAutomaticAllocationScope(Operation *op) {
  Operation *scope =
      op->getParentWithTrait<OpTrait::AutomaticAllocationScope>();
  assert(scope && "Expected op to be inside automatic allocation scope");
  return scope;
}

/// Allocate temporary buffers for data (vector) and mask (if present).
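///
/// E.g. (illustrative), for a masked transfer op on vector<5x4xf32>, buffers
/// similar to the following are created at the top of the enclosing
/// allocation scope (the value names are made up for exposition):
/// ```
/// %data_buf = memref.alloca() : memref<vector<5x4xf32>>
/// %mask_buf = memref.alloca() : memref<vector<5x4xi1>>
/// ```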
template <typename OpTy>
static BufferAllocs allocBuffers(OpBuilder &b, OpTy xferOp) {
  Location loc = xferOp.getLoc();
  OpBuilder::InsertionGuard guard(b);
  Operation *scope = getAutomaticAllocationScope(xferOp);
  assert(scope->getNumRegions() == 1 &&
         "AutomaticAllocationScope with >1 regions");
  b.setInsertionPointToStart(&scope->getRegion(0).front());

  BufferAllocs result;
  auto bufferType = MemRefType::get({}, xferOp.getVectorType());
  result.dataBuffer = b.create<memref::AllocaOp>(loc, bufferType);

  if (xferOp.getMask()) {
    auto maskType = MemRefType::get({}, xferOp.getMask().getType());
    auto maskBuffer = b.create<memref::AllocaOp>(loc, maskType);
    b.setInsertionPoint(xferOp);
    b.create<memref::StoreOp>(loc, xferOp.getMask(), maskBuffer);
    result.maskBuffer = b.create<memref::LoadOp>(loc, maskBuffer);
  }

  return result;
}

/// Given a MemRefType with VectorType element type, unpack one dimension from
/// the VectorType into the MemRefType.
///
/// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>>
static MemRefType unpackOneDim(MemRefType type) {
  auto vectorType = type.getElementType().dyn_cast<VectorType>();
  auto memrefShape = type.getShape();
  SmallVector<int64_t, 8> newMemrefShape;
  newMemrefShape.append(memrefShape.begin(), memrefShape.end());
  newMemrefShape.push_back(vectorType.getDimSize(0));
  return MemRefType::get(newMemrefShape,
                         VectorType::get(vectorType.getShape().drop_front(),
                                         vectorType.getElementType()));
}

/// Given a transfer op, find the memref from which the mask is loaded. This
/// is similar to Strategy<TransferWriteOp>::getBuffer.
template <typename OpTy>
static Value getMaskBuffer(OpTy xferOp) {
  assert(xferOp.getMask() && "Expected that transfer op has mask");
  auto loadOp = xferOp.getMask().template getDefiningOp<memref::LoadOp>();
  assert(loadOp && "Expected transfer op mask produced by LoadOp");
  return loadOp.getMemRef();
}

/// Codegen strategy, depending on the operation.
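///
/// Each specialization provides the hooks used by TransferOpConversion:
/// getBuffer, getBufferIndices, rewriteOp, handleOutOfBoundsDim, cleanup and
/// initialLoopState.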
template <typename OpTy>
struct Strategy;

/// Codegen strategy for vector TransferReadOp.
template <>
struct Strategy<TransferReadOp> {
  /// Find the StoreOp that is used for writing the current TransferReadOp's
  /// result to the temporary buffer allocation.
  static memref::StoreOp getStoreOp(TransferReadOp xferOp) {
    assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp");
    auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner());
    assert(storeOp && "Expected TransferReadOp result used by StoreOp");
    return storeOp;
  }

  /// Find the temporary buffer allocation. All labeled TransferReadOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ...
  /// memref.store %vec, %buf[...] ...
  /// ```
  static Value getBuffer(TransferReadOp xferOp) {
    return getStoreOp(xferOp).getMemRef();
  }

  /// Retrieve the indices of the current StoreOp that stores into the buffer.
  static void getBufferIndices(TransferReadOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto storeOp = getStoreOp(xferOp);
    auto prevIndices = memref::StoreOpAdaptor(storeOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration
  ///    variable `iv`.
  /// 2. Store the result into the (already `vector.type_cast`ed) buffer.
  ///
  /// E.g.:
  /// ```
  /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst
  ///     : memref<?x?x?xf32>, vector<4x3xf32>
  /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>>
  /// ```
  /// Is rewritten to:
  /// ```
  /// %casted = vector.type_cast %buf
  ///     : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
  /// for %j = 0 to 4 {
  ///   %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst
  ///       : memref<?x?x?xf32>, vector<3xf32>
  ///   memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// }
  /// ```
  ///
  /// Note: The loop and type cast are generated in TransferOpConversion.
  ///       The original TransferReadOp and store op are deleted in `cleanup`.
  /// Note: The `mask` operand is set in TransferOpConversion.
  static TransferReadOp rewriteOp(OpBuilder &b,
                                  VectorTransferToSCFOptions options,
                                  TransferReadOp xferOp, Value buffer, Value iv,
                                  ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto newXferOp = b.create<vector::TransferReadOp>(
        loc, vecType, xferOp.getSource(), xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)),
        xferOp.getPadding(), Value(), inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    b.create<memref::StoreOp>(loc, newXferOp.getVector(), buffer, storeIndices);
    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write
  /// padding value to the temporary buffer.
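  ///
  /// E.g. (illustrative), for the read example above, the out-of-bounds
  /// branch stores a splat of the padding value instead:
  /// ```
  /// %pad = vector.splat %cst : vector<3xf32>
  /// memref.store %pad, %casted[%i, %j] : memref<5x4xvector<3xf32>>
  /// ```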
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferReadOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange /*loopState*/) {
    SmallVector<Value, 8> storeIndices;
    getBufferIndices(xferOp, storeIndices);
    storeIndices.push_back(iv);

    Location loc = xferOp.getLoc();
    auto bufferType = buffer.getType().dyn_cast<ShapedType>();
    auto vecType = bufferType.getElementType().dyn_cast<VectorType>();
    auto vec = b.create<vector::SplatOp>(loc, vecType, xferOp.getPadding());
    b.create<memref::StoreOp>(loc, vec, buffer, storeIndices);

    return Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp,
                      scf::ForOp /*forOp*/) {
    rewriter.eraseOp(getStoreOp(xferOp));
    rewriter.eraseOp(xferOp);
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferReadOp xferOp) { return Value(); }
};

/// Codegen strategy for vector TransferWriteOp.
template <>
struct Strategy<TransferWriteOp> {
  /// Find the temporary buffer allocation. All labeled TransferWriteOps are
  /// used like this, where %buf is either the buffer allocation or a type cast
  /// of the buffer allocation:
  /// ```
  /// %vec = memref.load %buf[...] ...
  /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ...
  /// ```
  static Value getBuffer(TransferWriteOp xferOp) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    assert(loadOp && "Expected transfer op vector produced by LoadOp");
    return loadOp.getMemRef();
  }

  /// Retrieve the indices of the current LoadOp that loads from the buffer.
  static void getBufferIndices(TransferWriteOp xferOp,
                               SmallVector<Value, 8> &indices) {
    auto loadOp = xferOp.getVector().getDefiningOp<memref::LoadOp>();
    auto prevIndices = memref::LoadOpAdaptor(loadOp).indices();
    indices.append(prevIndices.begin(), prevIndices.end());
  }

  /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds
  /// accesses on the to-be-unpacked dimension.
  ///
  /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer,
  ///    using the loop iteration variable `iv`.
  /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back
  ///    to memory.
  ///
  /// Note: For more details, see comments on Strategy<TransferReadOp>.
  static TransferWriteOp rewriteOp(OpBuilder &b,
                                   VectorTransferToSCFOptions options,
                                   TransferWriteOp xferOp, Value buffer,
                                   Value iv, ValueRange loopState) {
    SmallVector<Value, 8> loadIndices;
    getBufferIndices(xferOp, loadIndices);
    loadIndices.push_back(iv);

    SmallVector<Value, 8> xferIndices;
    getXferIndices(b, xferOp, iv, xferIndices);

    Location loc = xferOp.getLoc();
    auto vec = b.create<memref::LoadOp>(loc, buffer, loadIndices);
    auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr());
    auto source = loopState.empty() ? xferOp.getSource() : loopState[0];
    Type type = isTensorOp(xferOp) ? xferOp.getShapedType() : Type();
    auto newXferOp = b.create<vector::TransferWriteOp>(
        loc, type, vec, source, xferIndices,
        AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(),
        inBoundsAttr);

    maybeApplyPassLabel(b, newXferOp, options.targetRank);

    return newXferOp;
  }

  /// Handle out-of-bounds accesses on the to-be-unpacked dimension.
  static Value handleOutOfBoundsDim(OpBuilder &b, TransferWriteOp xferOp,
                                    Value buffer, Value iv,
                                    ValueRange loopState) {
    return isTensorOp(xferOp) ? loopState[0] : Value();
  }

  /// Cleanup after rewriting the op.
  static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp,
                      scf::ForOp forOp) {
    if (isTensorOp(xferOp)) {
      assert(forOp->getNumResults() == 1 && "Expected one for loop result");
      rewriter.replaceOp(xferOp, forOp->getResult(0));
    } else {
      rewriter.eraseOp(xferOp);
    }
  }

  /// Return the initial loop state for the generated scf.for loop.
  static Value initialLoopState(TransferWriteOp xferOp) {
    return isTensorOp(xferOp) ? xferOp.getSource() : Value();
  }
};

template <typename OpTy>
LogicalResult checkPrepareXferOp(OpTy xferOp,
                                 VectorTransferToSCFOptions options) {
  if (xferOp->hasAttr(kPassLabel))
    return failure();
  if (xferOp.getVectorType().getRank() <= options.targetRank)
    return failure();
  if (isTensorOp(xferOp) && !options.lowerTensors)
    return failure();
  // Transfer ops that modify the element type are not supported at the moment.
  if (xferOp.getVectorType().getElementType() !=
      xferOp.getShapedType().getElementType())
    return failure();
  return success();
}

/// Prepare a TransferReadOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Label the TransferReadOp, marking it eligible for progressive lowering.
/// 3. Store the result of the TransferReadOp into the temporary buffer.
/// 4. Load the result from the temporary buffer and replace all uses of the
///    original TransferReadOp with this load.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %cst
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// %1 = vector.transfer_read %A[%a, %b, %c], %cst
///     { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32>
/// memref.store %1, %0[] : memref<vector<5x4xf32>>
/// %vec = memref.load %0[] : memref<vector<5x4xf32>>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
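/// E.g. (illustrative), a mask is staged through such a buffer in the same
/// way, and the re-loaded mask vector then replaces the `mask` operand of the
/// labeled transfer op:
/// ```
/// %m = memref.alloca() : memref<vector<5x4xi1>>
/// memref.store %mask, %m[] : memref<vector<5x4xi1>>
/// %mask_val = memref.load %m[] : memref<vector<5x4xi1>>
/// ```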
struct PrepareTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    auto buffers = allocBuffers(rewriter, xferOp);
    auto *newXfer = rewriter.clone(*xferOp.getOperation());
    newXfer->setAttr(kPassLabel, rewriter.getUnitAttr());
    if (xferOp.getMask()) {
      dyn_cast<TransferReadOp>(newXfer).getMaskMutable().assign(
          buffers.maskBuffer);
    }

    Location loc = xferOp.getLoc();
    rewriter.create<memref::StoreOp>(loc, newXfer->getResult(0),
                                     buffers.dataBuffer);
    rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer);

    return success();
  }
};

/// Prepare a TransferWriteOp for progressive lowering.
///
/// 1. Allocate a temporary buffer.
/// 2. Store the vector into the buffer.
/// 3. Load the vector from the buffer again.
/// 4. Use the loaded vector as a TransferWriteOp operand and label the op,
///    marking it eligible for progressive lowering via TransferOpConversion.
///
/// E.g.:
/// ```
/// vector.transfer_write %vec, %A[%a, %b, %c]
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = memref.alloca() : memref<vector<5x4xf32>>
/// memref.store %vec, %0[] : memref<vector<5x4xf32>>
/// %1 = memref.load %0[] : memref<vector<5x4xf32>>
/// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ }
///     : vector<5x4xf32>, memref<?x?x?xf32>
/// ```
///
/// Note: A second temporary buffer may be allocated for the `mask` operand.
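/// If a mask is present, it is staged through its own buffer in the same way
/// and re-attached to the labeled op (see `allocBuffers` above).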
struct PrepareTransferWriteConversion
    : public VectorToSCFPattern<TransferWriteOp> {
  using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;

  LogicalResult matchAndRewrite(TransferWriteOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (checkPrepareXferOp(xferOp, options).failed())
      return failure();

    Location loc = xferOp.getLoc();
    auto buffers = allocBuffers(rewriter, xferOp);
    rewriter.create<memref::StoreOp>(loc, xferOp.getVector(),
                                     buffers.dataBuffer);
    auto loadedVec = rewriter.create<memref::LoadOp>(loc, buffers.dataBuffer);
    rewriter.updateRootInPlace(xferOp, [&]() {
      xferOp.getVectorMutable().assign(loadedVec);
      xferOp->setAttr(kPassLabel, rewriter.getUnitAttr());
    });

    if (xferOp.getMask()) {
      rewriter.updateRootInPlace(xferOp, [&]() {
        xferOp.getMaskMutable().assign(buffers.maskBuffer);
      });
    }

    return success();
  }
};

/// Progressive lowering of vector transfer ops: Unpack one dimension.
///
/// 1. Unpack one dimension from the current buffer type and cast the buffer
///    to that new type. E.g.:
///    ```
///    %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>>
///    vector.transfer_write %vec ...
///    ```
///    The following cast is generated:
///    ```
///    %casted = vector.type_cast %0
///        : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>>
///    ```
/// 2. Generate a for loop and rewrite the transfer op according to the
///    corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be
///    out-of-bounds, generate an if-check and handle both cases separately.
/// 3. Clean up according to the corresponding Strategy<OpTy>.
///
/// Note: If the transfer op is a TransferWriteOp and operates on a tensor
/// source (as opposed to a memref source), then each iteration of the
/// generated scf.for loop yields the new tensor value. E.g.:
/// ```
/// %result = scf.for i = 0 to 5 {
///   %0 = memref.load %buffer[i] : memref<5xvector<4x3xf32>>
///   %1 = vector.transfer_write %0, %source[...]
///       : vector<4x3xf32>, tensor<5x4x3xf32>
///   scf.yield %1 : tensor<5x4x3xf32>
/// }
/// ```
template <typename OpTy>
struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
  using VectorToSCFPattern<OpTy>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    this->setHasBoundedRewriteRecursion();
  }

  LogicalResult matchAndRewrite(OpTy xferOp,
                                PatternRewriter &rewriter) const override {
    if (!xferOp->hasAttr(kPassLabel))
      return failure();

    // Find and cast data buffer. How the buffer can be found depends on OpTy.
    ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
    auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
    auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>();
    auto castedDataType = unpackOneDim(dataBufferType);
    auto castedDataBuffer =
        locB.create<vector::TypeCastOp>(castedDataType, dataBuffer);

    // If the xferOp has a mask: Find and cast mask buffer.
    Value castedMaskBuffer;
    if (xferOp.getMask()) {
      auto maskBuffer = getMaskBuffer(xferOp);
      auto maskBufferType =
          maskBuffer.getType().template dyn_cast<MemRefType>();
      if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
        // Do not unpack a dimension of the mask, if:
        // * To-be-unpacked transfer op dimension is a broadcast.
        // * Mask is 1D, i.e., the mask cannot be further unpacked.
        //   (That means that all remaining dimensions of the transfer op must
        //   be broadcasted.)
        castedMaskBuffer = maskBuffer;
      } else {
        auto castedMaskType = unpackOneDim(maskBufferType);
        castedMaskBuffer =
            locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
      }
    }

    // Loop bounds and step.
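    // E.g. (illustrative), for a casted buffer of type
    // memref<5x4xvector<3xf32>>, the generated loop iterates over the newly
    // unpacked dimension: scf.for %i = %c0 to %c4 step %c1 { ... }.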
    auto lb = locB.create<arith::ConstantIndexOp>(0);
    auto ub = locB.create<arith::ConstantIndexOp>(
        castedDataType.getDimSize(castedDataType.getRank() - 1));
    auto step = locB.create<arith::ConstantIndexOp>(1);
    // TransferWriteOps that operate on tensors return the modified tensor and
    // require a loop state.
    auto loopState = Strategy<OpTy>::initialLoopState(xferOp);

    // Generate for loop.
    auto result = locB.create<scf::ForOp>(
        lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(),
        [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) {
          Type stateType = loopState.empty() ? Type() : loopState[0].getType();

          auto result = generateInBoundsCheck(
              b, xferOp, iv, unpackedDim(xferOp),
              stateType ? TypeRange(stateType) : TypeRange(),
              /*inBoundsCase=*/
              [&](OpBuilder &b, Location loc) {
                // Create new transfer op.
                OpTy newXfer = Strategy<OpTy>::rewriteOp(
                    b, this->options, xferOp, castedDataBuffer, iv, loopState);

                // If old transfer op has a mask: Set mask on new transfer op.
                // Special case: If the mask of the old transfer op is 1D and
                // the unpacked dim is not a broadcast, no mask is needed on
                // the new transfer op.
                if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
                                         xferOp.getMaskType().getRank() > 1)) {
                  OpBuilder::InsertionGuard guard(b);
                  b.setInsertionPoint(newXfer); // Insert load before newXfer.

                  SmallVector<Value, 8> loadIndices;
                  Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
                  // In case of broadcast: Use same indices to load from memref
                  // as before.
                  if (!xferOp.isBroadcastDim(0))
                    loadIndices.push_back(iv);

                  auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
                                                       loadIndices);
                  rewriter.updateRootInPlace(newXfer, [&]() {
                    newXfer.getMaskMutable().assign(mask);
                  });
                }

                return loopState.empty() ? Value() : newXfer->getResult(0);
              },
              /*outOfBoundsCase=*/
              [&](OpBuilder &b, Location /*loc*/) {
                return Strategy<OpTy>::handleOutOfBoundsDim(
                    b, xferOp, castedDataBuffer, iv, loopState);
              });

          maybeYieldValue(b, loc, !loopState.empty(), result);
        });

    Strategy<OpTy>::cleanup(rewriter, xferOp, result);
    return success();
  }
};

} // namespace lowering_n_d

namespace lowering_n_d_unrolled {

/// If the original transfer op has a mask, compute the mask of the new
/// transfer op (for the current iteration `i`) and assign it.
template <typename OpTy>
static void maybeAssignMask(OpBuilder &b, OpTy xferOp, OpTy newXferOp,
                            int64_t i) {
  if (!xferOp.getMask())
    return;

  if (xferOp.isBroadcastDim(0)) {
    // To-be-unpacked dimension is a broadcast, which does not have a
    // corresponding mask dimension. Mask attribute remains unchanged.
    newXferOp.getMaskMutable().assign(xferOp.getMask());
    return;
  }

  if (xferOp.getMaskType().getRank() > 1) {
    // Unpack one dimension of the mask.
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(newXferOp); // Insert load before newXfer.

    llvm::SmallVector<int64_t, 1> indices({i});
    Location loc = xferOp.getLoc();
    auto newMask = b.create<vector::ExtractOp>(loc, xferOp.getMask(), indices);
    newXferOp.getMaskMutable().assign(newMask);
  }

  // If we end up here: The mask of the old transfer op is 1D and the unpacked
  // dim is not a broadcast, so no mask is needed on the new transfer op.
  // `generateInBoundsCheck` will have evaluated the mask already.
}

/// Progressive lowering of vector TransferReadOp with unrolling: Unpack one
/// dimension. This is similar to TransferOpConversion<TransferReadOp>, but no
/// memref buffer is allocated and the SCF loop is fully unrolled.
///
/// E.g.:
/// ```
/// %vec = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<5x4xf32>
/// ```
/// is rewritten to IR such as (simplified):
/// ```
/// %v_init = splat %padding : vector<5x4xf32>
/// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32>
/// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32>
/// ...
/// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding
///     : memref<?x?x?xf32>, vector<4xf32>
/// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32>
/// ```
///
/// Note: As an optimization, if the result of the original TransferReadOp
/// was directly inserted into another vector, no new %v_init vector is
/// created. Instead, the new TransferReadOp results are inserted into that
/// vector.
struct UnrollTransferReadConversion
    : public VectorToSCFPattern<TransferReadOp> {
  using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern;

  void initialize() {
    // This pattern recursively unpacks one dimension at a time. The recursion
    // is bounded because the rank is strictly decreasing.
    setHasBoundedRewriteRecursion();
  }

  /// Return the vector into which the newly created TransferReadOp results
  /// are inserted.
  Value getResultVector(TransferReadOp xferOp,
                        PatternRewriter &rewriter) const {
    if (auto insertOp = getInsertOp(xferOp))
      return insertOp.getDest();
    Location loc = xferOp.getLoc();
    return rewriter.create<vector::SplatOp>(loc, xferOp.getVectorType(),
                                            xferOp.getPadding());
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation.
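  ///
  /// E.g. (illustrative), this matches IR of the form:
  /// ```
  /// %0 = vector.transfer_read ... : memref<?x?xf32>, vector<4xf32>
  /// %1 = vector.insert %0, %v [3] : vector<4xf32> into vector<5x4xf32>
  /// ```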
  vector::InsertOp getInsertOp(TransferReadOp xferOp) const {
    if (xferOp->hasOneUse()) {
      Operation *xferOpUser = *xferOp->getUsers().begin();
      if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser))
        return insertOp;
    }

    return vector::InsertOp();
  }

  /// If the result of the TransferReadOp has exactly one user, which is a
  /// vector::InsertOp, return that operation's indices.
  void getInsertionIndices(TransferReadOp xferOp,
                           SmallVector<int64_t, 8> &indices) const {
    if (auto insertOp = getInsertOp(xferOp)) {
      llvm::for_each(insertOp.getPosition(), [&](Attribute attr) {
        indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
      });
    }
  }

  /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
  /// accesses, and broadcasts and transposes in permutation maps.
  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    if (xferOp.getVectorType().getRank() <= options.targetRank)
      return failure();
    if (isTensorOp(xferOp) && !options.lowerTensors)
      return failure();
    // Transfer ops that modify the element type are not supported at the
    // moment.
    if (xferOp.getVectorType().getElementType() !=
        xferOp.getShapedType().getElementType())
      return failure();

    auto insertOp = getInsertOp(xferOp);
    auto vec = getResultVector(xferOp, rewriter);
    auto vecType = vec.getType().dyn_cast<VectorType>();
    auto xferVecType = xferOp.getVectorType();
    auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(),
                                          xferVecType.getElementType());
    int64_t dimSize = xferVecType.getShape()[0];

    // Generate fully unrolled loop of transfer ops.
    Location loc = xferOp.getLoc();
    for (int64_t i = 0; i < dimSize; ++i) {
      Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i);

      vec = generateInBoundsCheck(
          rewriter, xferOp, iv, unpackedDim(xferOp), TypeRange(vecType),
          /*inBoundsCase=*/
          [&](OpBuilder &b, Location loc) {
            // Indices for the new transfer op.
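            // E.g. (illustrative), in iteration i the index of the unpacked
            // dim %b roughly becomes affine.apply (d0 + d1)(%b, %c_i), which
            // may fold into a constant-offset index.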
9200f241638SMatthias Springer SmallVector<Value, 8> xferIndices; 9216825bfe2SNicolas Vasilache getXferIndices(b, xferOp, iv, xferIndices); 9220f241638SMatthias Springer 9230f241638SMatthias Springer // Indices for the new vector.insert op. 9240f241638SMatthias Springer SmallVector<int64_t, 8> insertionIndices; 9250f241638SMatthias Springer getInsertionIndices(xferOp, insertionIndices); 9260f241638SMatthias Springer insertionIndices.push_back(i); 9270f241638SMatthias Springer 9287c38fd60SJacques Pienaar auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); 9296825bfe2SNicolas Vasilache auto newXferOp = b.create<vector::TransferReadOp>( 9307c38fd60SJacques Pienaar loc, newXferVecType, xferOp.getSource(), xferIndices, 9316825bfe2SNicolas Vasilache AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), 9327c38fd60SJacques Pienaar xferOp.getPadding(), Value(), inBoundsAttr); 9330f241638SMatthias Springer maybeAssignMask(b, xferOp, newXferOp, i); 9346825bfe2SNicolas Vasilache return b.create<vector::InsertOp>(loc, newXferOp, vec, 9356825bfe2SNicolas Vasilache insertionIndices); 9360f241638SMatthias Springer }, 9370f241638SMatthias Springer /*outOfBoundsCase=*/ 9380f241638SMatthias Springer [&](OpBuilder &b, Location loc) { 9390f241638SMatthias Springer // Loop through original (unmodified) vector. 9400f241638SMatthias Springer return vec; 9410f241638SMatthias Springer }); 9420f241638SMatthias Springer } 9430f241638SMatthias Springer 9440f241638SMatthias Springer if (insertOp) { 9450f241638SMatthias Springer // Rewrite single user of the old TransferReadOp, which was an InsertOp. 9460f241638SMatthias Springer rewriter.replaceOp(insertOp, vec); 9470f241638SMatthias Springer rewriter.eraseOp(xferOp); 9480f241638SMatthias Springer } else { 9490f241638SMatthias Springer rewriter.replaceOp(xferOp, vec); 9500f241638SMatthias Springer } 9510f241638SMatthias Springer 9520f241638SMatthias Springer return success(); 9530f241638SMatthias Springer } 9540f241638SMatthias Springer }; 9550f241638SMatthias Springer 9560f241638SMatthias Springer /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one 9570f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no 9580f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled. 9590f241638SMatthias Springer /// 9600f241638SMatthias Springer /// ``` 9610f241638SMatthias Springer /// E.g.: 9620f241638SMatthias Springer /// ``` 9630f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c] 9640f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?x?xf32> 9650f241638SMatthias Springer /// ``` 9660f241638SMatthias Springer /// is rewritten to IR such as (simplified): 9670f241638SMatthias Springer /// ``` 9680f241638SMatthias Springer /// %v0 = vector.extract %vec[0] : vector<5x4xf32> 9690f241638SMatthias Springer /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...> 9700f241638SMatthias Springer /// %v1 = vector.extract %vec[1] : vector<5x4xf32> 9710f241638SMatthias Springer /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...> 9720f241638SMatthias Springer /// ... 
9730f241638SMatthias Springer /// %v4 = vector.extract %vec[4] : vector<5x4xf32>
9740f241638SMatthias Springer /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...>
9750f241638SMatthias Springer /// ```
9760f241638SMatthias Springer ///
9770f241638SMatthias Springer /// Note: As an optimization, if the vector of the original TransferWriteOp
9780f241638SMatthias Springer /// was directly extracted from another vector via an ExtractOp `a`, extract
9790f241638SMatthias Springer /// the vectors for the newly generated TransferWriteOps from `a`'s input. By
9800f241638SMatthias Springer /// doing so, `a` may become dead, and the number of ExtractOps generated during
9810f241638SMatthias Springer /// recursive application of this pattern will be minimal.
9820f241638SMatthias Springer struct UnrollTransferWriteConversion
9832ca887deSMatthias Springer     : public VectorToSCFPattern<TransferWriteOp> {
9842ca887deSMatthias Springer   using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern;
9850f241638SMatthias Springer 
986700b64dcSMatthias Springer   void initialize() {
987700b64dcSMatthias Springer     // This pattern recursively unpacks one dimension at a time. The recursion
988700b64dcSMatthias Springer     // is bounded as the rank is strictly decreasing.
989700b64dcSMatthias Springer     setHasBoundedRewriteRecursion();
990700b64dcSMatthias Springer   }
991700b64dcSMatthias Springer 
9920f241638SMatthias Springer   /// Return the vector from which newly generated ExtractOps will extract.
9930f241638SMatthias Springer   Value getDataVector(TransferWriteOp xferOp) const {
9940f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp))
9957c38fd60SJacques Pienaar       return extractOp.getVector();
9967c38fd60SJacques Pienaar     return xferOp.getVector();
9970f241638SMatthias Springer   }
9980f241638SMatthias Springer 
9990f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return it.
10000f241638SMatthias Springer   vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const {
10017c38fd60SJacques Pienaar     if (auto *op = xferOp.getVector().getDefiningOp())
10020f241638SMatthias Springer       return dyn_cast<vector::ExtractOp>(op);
10030f241638SMatthias Springer     return vector::ExtractOp();
10040f241638SMatthias Springer   }
10050f241638SMatthias Springer 
10060f241638SMatthias Springer   /// If the input of the given TransferWriteOp is an ExtractOp, return its
10070f241638SMatthias Springer   /// indices.
10080f241638SMatthias Springer   void getExtractionIndices(TransferWriteOp xferOp,
10090f241638SMatthias Springer                             SmallVector<int64_t, 8> &indices) const {
10100f241638SMatthias Springer     if (auto extractOp = getExtractOp(xferOp)) {
10117c38fd60SJacques Pienaar       llvm::for_each(extractOp.getPosition(), [&](Attribute attr) {
10120f241638SMatthias Springer         indices.push_back(attr.dyn_cast<IntegerAttr>().getInt());
10130f241638SMatthias Springer       });
10140f241638SMatthias Springer     }
10150f241638SMatthias Springer   }
10160f241638SMatthias Springer 
10170f241638SMatthias Springer   /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds
10180f241638SMatthias Springer   /// accesses, and broadcasts and transposes in permutation maps.
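  /// For example (illustrative IR; values and shapes are made up), a
  /// transposing write such as
  /// ```
  /// vector.transfer_write %vec, %A[%a, %b]
  ///     {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
  ///     : vector<4x9xf32>, memref<?x?xf32>
  /// ```
  /// is unpacked along the major vector dimension into four writes of
  /// vector<9xf32> at %A[%a, %b + i], each with the remaining permutation map
  /// affine_map<(d0, d1) -> (d0)>.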
10190f241638SMatthias Springer LogicalResult matchAndRewrite(TransferWriteOp xferOp, 10200f241638SMatthias Springer PatternRewriter &rewriter) const override { 10212ca887deSMatthias Springer if (xferOp.getVectorType().getRank() <= options.targetRank) 10220f241638SMatthias Springer return failure(); 1023bd20756dSMatthias Springer if (isTensorOp(xferOp) && !options.lowerTensors) 10248fb48979SMatthias Springer return failure(); 1025f718a53dSMatthias Springer // Transfer ops that modify the element type are not supported atm. 1026f718a53dSMatthias Springer if (xferOp.getVectorType().getElementType() != 1027f718a53dSMatthias Springer xferOp.getShapedType().getElementType()) 1028f718a53dSMatthias Springer return failure(); 10290f241638SMatthias Springer 10300f241638SMatthias Springer auto vec = getDataVector(xferOp); 10310f241638SMatthias Springer auto xferVecType = xferOp.getVectorType(); 10320f241638SMatthias Springer int64_t dimSize = xferVecType.getShape()[0]; 10337c38fd60SJacques Pienaar auto source = xferOp.getSource(); // memref or tensor to be written to. 1034bd20756dSMatthias Springer auto sourceType = isTensorOp(xferOp) ? xferOp.getShapedType() : Type(); 10350f241638SMatthias Springer 10360f241638SMatthias Springer // Generate fully unrolled loop of transfer ops. 10376825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 10380f241638SMatthias Springer for (int64_t i = 0; i < dimSize; ++i) { 1039a54f4eaeSMogball Value iv = rewriter.create<arith::ConstantIndexOp>(loc, i); 10400f241638SMatthias Springer 1041bd20756dSMatthias Springer auto updatedSource = generateInBoundsCheck( 10426825bfe2SNicolas Vasilache rewriter, xferOp, iv, unpackedDim(xferOp), 1043bd20756dSMatthias Springer isTensorOp(xferOp) ? TypeRange(sourceType) : TypeRange(), 1044bd20756dSMatthias Springer /*inBoundsCase=*/ 1045bd20756dSMatthias Springer [&](OpBuilder &b, Location loc) { 10460f241638SMatthias Springer // Indices for the new transfer op. 10470f241638SMatthias Springer SmallVector<Value, 8> xferIndices; 10486825bfe2SNicolas Vasilache getXferIndices(b, xferOp, iv, xferIndices); 10490f241638SMatthias Springer 10500f241638SMatthias Springer // Indices for the new vector.extract op. 10510f241638SMatthias Springer SmallVector<int64_t, 8> extractionIndices; 10520f241638SMatthias Springer getExtractionIndices(xferOp, extractionIndices); 10530f241638SMatthias Springer extractionIndices.push_back(i); 10540f241638SMatthias Springer 10556825bfe2SNicolas Vasilache auto extracted = 10566825bfe2SNicolas Vasilache b.create<vector::ExtractOp>(loc, vec, extractionIndices); 10577c38fd60SJacques Pienaar auto inBoundsAttr = dropFirstElem(b, xferOp.getInBoundsAttr()); 10586825bfe2SNicolas Vasilache auto newXferOp = b.create<vector::TransferWriteOp>( 1059bd20756dSMatthias Springer loc, sourceType, extracted, source, xferIndices, 10606825bfe2SNicolas Vasilache AffineMapAttr::get(unpackedPermutationMap(b, xferOp)), Value(), 10616825bfe2SNicolas Vasilache inBoundsAttr); 10620f241638SMatthias Springer 10630f241638SMatthias Springer maybeAssignMask(b, xferOp, newXferOp, i); 1064bd20756dSMatthias Springer 1065bd20756dSMatthias Springer return isTensorOp(xferOp) ? newXferOp->getResult(0) : Value(); 1066bd20756dSMatthias Springer }, 1067bd20756dSMatthias Springer /*outOfBoundsCase=*/ 1068bd20756dSMatthias Springer [&](OpBuilder &b, Location loc) { 1069bd20756dSMatthias Springer return isTensorOp(xferOp) ? 
source : Value(); 10700f241638SMatthias Springer }); 1071bd20756dSMatthias Springer 1072bd20756dSMatthias Springer if (isTensorOp(xferOp)) 1073bd20756dSMatthias Springer source = updatedSource; 10740f241638SMatthias Springer } 10750f241638SMatthias Springer 1076bd20756dSMatthias Springer if (isTensorOp(xferOp)) 1077bd20756dSMatthias Springer rewriter.replaceOp(xferOp, source); 1078bd20756dSMatthias Springer else 10790f241638SMatthias Springer rewriter.eraseOp(xferOp); 1080bd20756dSMatthias Springer 10810f241638SMatthias Springer return success(); 10820f241638SMatthias Springer } 10830f241638SMatthias Springer }; 10840f241638SMatthias Springer 1085a088bed4SMatthias Springer } // namespace lowering_n_d_unrolled 1086a088bed4SMatthias Springer 1087a088bed4SMatthias Springer namespace lowering_1_d { 1088a088bed4SMatthias Springer 10890f241638SMatthias Springer /// Compute the indices into the memref for the LoadOp/StoreOp generated as 10900f241638SMatthias Springer /// part of TransferOp1dConversion. Return the memref dimension on which 10910f241638SMatthias Springer /// the transfer is operating. A return value of None indicates a broadcast. 10920f241638SMatthias Springer template <typename OpTy> 10930f241638SMatthias Springer static Optional<int64_t> 10946825bfe2SNicolas Vasilache get1dMemrefIndices(OpBuilder &b, OpTy xferOp, Value iv, 10950f241638SMatthias Springer SmallVector<Value, 8> &memrefIndices) { 10967c38fd60SJacques Pienaar auto indices = xferOp.getIndices(); 10977c38fd60SJacques Pienaar auto map = xferOp.getPermutationMap(); 1098c537a943SNicolas Vasilache assert(xferOp.getTransferRank() > 0 && "unexpected 0-d transfer"); 10990f241638SMatthias Springer 11000f241638SMatthias Springer memrefIndices.append(indices.begin(), indices.end()); 11010f241638SMatthias Springer assert(map.getNumResults() == 1 && 11020f241638SMatthias Springer "Expected 1 permutation map result for 1D transfer"); 11030f241638SMatthias Springer if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) { 11046825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 11050f241638SMatthias Springer auto dim = expr.getPosition(); 11066825bfe2SNicolas Vasilache AffineExpr d0, d1; 11076825bfe2SNicolas Vasilache bindDims(xferOp.getContext(), d0, d1); 11086825bfe2SNicolas Vasilache Value offset = memrefIndices[dim]; 11096825bfe2SNicolas Vasilache memrefIndices[dim] = makeComposedAffineApply(b, loc, d0 + d1, {offset, iv}); 11100f241638SMatthias Springer return dim; 11110f241638SMatthias Springer } 11120f241638SMatthias Springer 11130f241638SMatthias Springer assert(xferOp.isBroadcastDim(0) && 11140f241638SMatthias Springer "Expected AffineDimExpr or AffineConstantExpr"); 11150f241638SMatthias Springer return None; 11160f241638SMatthias Springer } 11170f241638SMatthias Springer 11180f241638SMatthias Springer /// Codegen strategy for TransferOp1dConversion, depending on the 11190f241638SMatthias Springer /// operation. 11200f241638SMatthias Springer template <typename OpTy> 11210f241638SMatthias Springer struct Strategy1d; 11220f241638SMatthias Springer 11230f241638SMatthias Springer /// Codegen strategy for TransferReadOp. 
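/// The generated loop has roughly the following shape (simplified pseudo-IR
/// with illustrative names; the per-element in-bounds check is omitted):
/// ```
/// %init = vector.splat %padding : vector<9xf32>
/// %vec = scf.for %i = %c0 to %c9 step %c1 iter_args(%acc = %init)
///     -> (vector<9xf32>) {
///   %el = memref.load %A[%a + %i, %b] : memref<?x?xf32>
///   %new = vector.insertelement %el, %acc[%i : index] : vector<9xf32>
///   scf.yield %new : vector<9xf32>
/// }
/// ```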
11240f241638SMatthias Springer template <>
11250f241638SMatthias Springer struct Strategy1d<TransferReadOp> {
11266825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
11270f241638SMatthias Springer                                   TransferReadOp xferOp, Value iv,
11280f241638SMatthias Springer                                   ValueRange loopState) {
11290f241638SMatthias Springer     SmallVector<Value, 8> indices;
11306825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
11310f241638SMatthias Springer     auto vec = loopState[0];
11320f241638SMatthias Springer 
11330f241638SMatthias Springer     // In case of out-of-bounds access, leave `vec` as is (was initialized with
11340f241638SMatthias Springer     // padding value).
11350f241638SMatthias Springer     auto nextVec = generateInBoundsCheck(
11366825bfe2SNicolas Vasilache         b, xferOp, iv, dim, TypeRange(xferOp.getVectorType()),
11370f241638SMatthias Springer         /*inBoundsCase=*/
11386825bfe2SNicolas Vasilache         [&](OpBuilder &b, Location loc) {
11397c38fd60SJacques Pienaar           Value val =
11407c38fd60SJacques Pienaar               b.create<memref::LoadOp>(loc, xferOp.getSource(), indices);
11417c5ecc8bSMogball           return b.create<vector::InsertElementOp>(loc, val, vec, iv);
11420f241638SMatthias Springer         },
11430f241638SMatthias Springer         /*outOfBoundsCase=*/
11440f241638SMatthias Springer         [&](OpBuilder & /*b*/, Location loc) { return vec; });
11456825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc, nextVec);
11460f241638SMatthias Springer   }
11470f241638SMatthias Springer 
11486825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferReadOp xferOp) {
11490f241638SMatthias Springer     // Initialize vector with padding value.
11506825bfe2SNicolas Vasilache     Location loc = xferOp.getLoc();
11516a8ba318SRiver Riddle     return b.create<vector::SplatOp>(loc, xferOp.getVectorType(),
11527c38fd60SJacques Pienaar                                      xferOp.getPadding());
11530f241638SMatthias Springer   }
11540f241638SMatthias Springer };
11550f241638SMatthias Springer 
11560f241638SMatthias Springer /// Codegen strategy for TransferWriteOp.
11570f241638SMatthias Springer template <>
11580f241638SMatthias Springer struct Strategy1d<TransferWriteOp> {
11596825bfe2SNicolas Vasilache   static void generateForLoopBody(OpBuilder &b, Location loc,
11600f241638SMatthias Springer                                   TransferWriteOp xferOp, Value iv,
11610f241638SMatthias Springer                                   ValueRange /*loopState*/) {
11620f241638SMatthias Springer     SmallVector<Value, 8> indices;
11636825bfe2SNicolas Vasilache     auto dim = get1dMemrefIndices(b, xferOp, iv, indices);
11640f241638SMatthias Springer 
11650f241638SMatthias Springer     // Nothing to do in case of out-of-bounds access.
11660f241638SMatthias Springer     generateInBoundsCheck(
11676825bfe2SNicolas Vasilache         b, xferOp, iv, dim,
11686825bfe2SNicolas Vasilache         /*inBoundsCase=*/[&](OpBuilder &b, Location loc) {
11696825bfe2SNicolas Vasilache           auto val =
11707c38fd60SJacques Pienaar               b.create<vector::ExtractElementOp>(loc, xferOp.getVector(), iv);
11717c38fd60SJacques Pienaar           b.create<memref::StoreOp>(loc, val, xferOp.getSource(), indices);
11720f241638SMatthias Springer         });
11736825bfe2SNicolas Vasilache     b.create<scf::YieldOp>(loc);
11740f241638SMatthias Springer   }
11750f241638SMatthias Springer 
11766825bfe2SNicolas Vasilache   static Value initialLoopState(OpBuilder &b, TransferWriteOp xferOp) {
11776825bfe2SNicolas Vasilache     return Value();
11786825bfe2SNicolas Vasilache   }
11790f241638SMatthias Springer };
11800f241638SMatthias Springer 
11810f241638SMatthias Springer /// Return true if the last dimension of the MemRefType has unit stride.
11820f241638SMatthias Springer static bool isLastMemrefDimUnitStride(MemRefType type) {
11830f241638SMatthias Springer   int64_t offset;
11840f241638SMatthias Springer   SmallVector<int64_t, 4> strides;
11850f241638SMatthias Springer   auto successStrides = getStridesAndOffset(type, strides, offset);
11865017b0f8SMatthias Springer   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
11870f241638SMatthias Springer }
11880f241638SMatthias Springer 
11890f241638SMatthias Springer /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is
11900f241638SMatthias Springer /// necessary in cases where a 1D vector transfer op cannot be lowered into
11910f241638SMatthias Springer /// vector load/stores due to non-unit strides or broadcasts:
11920f241638SMatthias Springer ///
11930f241638SMatthias Springer /// * Transfer dimension is not the last memref dimension
11940f241638SMatthias Springer /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast)
11950f241638SMatthias Springer /// * Memref has a layout map with non-unit stride on the last dimension
11960f241638SMatthias Springer ///
11970f241638SMatthias Springer /// This pattern generates IR as follows:
11980f241638SMatthias Springer ///
11990f241638SMatthias Springer /// 1. Generate a for loop iterating over each vector element.
12000f241638SMatthias Springer /// 2. Inside the loop, generate an InsertElementOp or ExtractElementOp,
12010f241638SMatthias Springer ///    depending on OpTy.
12020f241638SMatthias Springer ///
12030f241638SMatthias Springer /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp
12040f241638SMatthias Springer /// can be generated instead of TransferOp1dConversion. Add such a pattern
12050f241638SMatthias Springer /// to ConvertVectorToLLVM.
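/// For illustration, a memref layout that takes this path is one whose
/// innermost stride is not 1 (the type below is made up), e.g.:
/// ```
/// memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * s0 + d1 * 2)>>
/// ```
/// Here the last-dimension stride is 2, so isLastMemrefDimUnitStride returns
/// false and the transfer is not left to ConvertVectorToLLVM.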
12060f241638SMatthias Springer /// 12070f241638SMatthias Springer /// E.g.: 12080f241638SMatthias Springer /// ``` 12090f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b] 12100f241638SMatthias Springer /// {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} 12110f241638SMatthias Springer /// : vector<9xf32>, memref<?x?xf32> 12120f241638SMatthias Springer /// ``` 12130f241638SMatthias Springer /// Is rewritten to approximately the following pseudo-IR: 12140f241638SMatthias Springer /// ``` 12150f241638SMatthias Springer /// for i = 0 to 9 { 12160f241638SMatthias Springer /// %t = vector.extractelement %vec[i] : vector<9xf32> 12170f241638SMatthias Springer /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> 12180f241638SMatthias Springer /// } 12190f241638SMatthias Springer /// ``` 12200f241638SMatthias Springer template <typename OpTy> 12212ca887deSMatthias Springer struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> { 12222ca887deSMatthias Springer using VectorToSCFPattern<OpTy>::VectorToSCFPattern; 12230f241638SMatthias Springer 12240f241638SMatthias Springer LogicalResult matchAndRewrite(OpTy xferOp, 12250f241638SMatthias Springer PatternRewriter &rewriter) const override { 1226c537a943SNicolas Vasilache // TODO: support 0-d corner case. 1227c537a943SNicolas Vasilache if (xferOp.getTransferRank() == 0) 1228c537a943SNicolas Vasilache return failure(); 12297c38fd60SJacques Pienaar auto map = xferOp.getPermutationMap(); 12300f241638SMatthias Springer auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>(); 12310f241638SMatthias Springer 12320f241638SMatthias Springer if (!memRefType) 12330f241638SMatthias Springer return failure(); 12340f241638SMatthias Springer if (xferOp.getVectorType().getRank() != 1) 12350f241638SMatthias Springer return failure(); 12360f241638SMatthias Springer if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType)) 12370f241638SMatthias Springer return failure(); // Handled by ConvertVectorToLLVM 12380f241638SMatthias Springer 12390f241638SMatthias Springer // Loop bounds, step, state... 12406825bfe2SNicolas Vasilache Location loc = xferOp.getLoc(); 12410f241638SMatthias Springer auto vecType = xferOp.getVectorType(); 1242a54f4eaeSMogball auto lb = rewriter.create<arith::ConstantIndexOp>(loc, 0); 1243a54f4eaeSMogball auto ub = 1244a54f4eaeSMogball rewriter.create<arith::ConstantIndexOp>(loc, vecType.getDimSize(0)); 1245a54f4eaeSMogball auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); 12466825bfe2SNicolas Vasilache auto loopState = Strategy1d<OpTy>::initialLoopState(rewriter, xferOp); 12470f241638SMatthias Springer 12480f241638SMatthias Springer // Generate for loop. 12490f241638SMatthias Springer rewriter.replaceOpWithNewOp<scf::ForOp>( 12500f241638SMatthias Springer xferOp, lb, ub, step, loopState ? 
ValueRange(loopState) : ValueRange(), 12516825bfe2SNicolas Vasilache [&](OpBuilder &b, Location loc, Value iv, ValueRange loopState) { 12526825bfe2SNicolas Vasilache Strategy1d<OpTy>::generateForLoopBody(b, loc, xferOp, iv, loopState); 12530f241638SMatthias Springer }); 12540f241638SMatthias Springer 12550f241638SMatthias Springer return success(); 12560f241638SMatthias Springer } 12570f241638SMatthias Springer }; 12584ead2cf7SAlex Zinenko 1259a088bed4SMatthias Springer } // namespace lowering_1_d 1260df63eedeSBenjamin Kramer } // namespace 1261df63eedeSBenjamin Kramer 126247f175b0SRiver Riddle void mlir::populateVectorToSCFConversionPatterns( 1263dc4e913bSChris Lattner RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) { 12640f241638SMatthias Springer if (options.unroll) { 1265a088bed4SMatthias Springer patterns.add<lowering_n_d_unrolled::UnrollTransferReadConversion, 1266a088bed4SMatthias Springer lowering_n_d_unrolled::UnrollTransferWriteConversion>( 12672ca887deSMatthias Springer patterns.getContext(), options); 12680f241638SMatthias Springer } else { 1269a088bed4SMatthias Springer patterns.add<lowering_n_d::PrepareTransferReadConversion, 1270a088bed4SMatthias Springer lowering_n_d::PrepareTransferWriteConversion, 1271a088bed4SMatthias Springer lowering_n_d::TransferOpConversion<TransferReadOp>, 1272a088bed4SMatthias Springer lowering_n_d::TransferOpConversion<TransferWriteOp>>( 1273a088bed4SMatthias Springer patterns.getContext(), options); 12740f241638SMatthias Springer } 12750f241638SMatthias Springer 12762ca887deSMatthias Springer if (options.targetRank == 1) { 1277a088bed4SMatthias Springer patterns.add<lowering_1_d::TransferOp1dConversion<TransferReadOp>, 1278a088bed4SMatthias Springer lowering_1_d::TransferOp1dConversion<TransferWriteOp>>( 1279a088bed4SMatthias Springer patterns.getContext(), options); 12800f241638SMatthias Springer } 12814ead2cf7SAlex Zinenko } 12823393cc4cSNicolas Vasilache 12835f9e0466SNicolas Vasilache namespace { 12845f9e0466SNicolas Vasilache 12855f9e0466SNicolas Vasilache struct ConvertVectorToSCFPass 12865f9e0466SNicolas Vasilache : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> { 12875f9e0466SNicolas Vasilache ConvertVectorToSCFPass() = default; 12885f9e0466SNicolas Vasilache ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { 12895f9e0466SNicolas Vasilache this->fullUnroll = options.unroll; 12902ca887deSMatthias Springer this->targetRank = options.targetRank; 1291fb7ec1f1SMatthias Springer this->lowerPermutationMaps = options.lowerPermutationMaps; 1292558e7401SMatthias Springer this->lowerTensors = options.lowerTensors; 12935f9e0466SNicolas Vasilache } 12945f9e0466SNicolas Vasilache 129541574554SRiver Riddle void runOnOperation() override { 12962ca887deSMatthias Springer VectorTransferToSCFOptions options; 1297fb7ec1f1SMatthias Springer options.unroll = fullUnroll; 1298fb7ec1f1SMatthias Springer options.targetRank = targetRank; 1299fb7ec1f1SMatthias Springer options.lowerPermutationMaps = lowerPermutationMaps; 1300558e7401SMatthias Springer options.lowerTensors = lowerTensors; 1301fb7ec1f1SMatthias Springer 1302fb7ec1f1SMatthias Springer // Lower permutation maps first. 
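    // These patterns materialize vector.transpose / vector.broadcast ops for
    // transfers whose permutation map contains transposes or broadcasts, so
    // the conversion patterns applied below mostly see (minor) identity maps.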
1303fb7ec1f1SMatthias Springer if (lowerPermutationMaps) { 130447f175b0SRiver Riddle RewritePatternSet lowerTransferPatterns(&getContext()); 1305fb7ec1f1SMatthias Springer mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( 1306fb7ec1f1SMatthias Springer lowerTransferPatterns); 130741574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), 1308fb7ec1f1SMatthias Springer std::move(lowerTransferPatterns)); 1309fb7ec1f1SMatthias Springer } 13102ca887deSMatthias Springer 131147f175b0SRiver Riddle RewritePatternSet patterns(&getContext()); 13122ca887deSMatthias Springer populateVectorToSCFConversionPatterns(patterns, options); 131341574554SRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); 13145f9e0466SNicolas Vasilache } 13155f9e0466SNicolas Vasilache }; 13165f9e0466SNicolas Vasilache 13175f9e0466SNicolas Vasilache } // namespace 13185f9e0466SNicolas Vasilache 13195f9e0466SNicolas Vasilache std::unique_ptr<Pass> 13205f9e0466SNicolas Vasilache mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { 13215f9e0466SNicolas Vasilache return std::make_unique<ConvertVectorToSCFPass>(options); 13225f9e0466SNicolas Vasilache } 1323
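// Example (illustrative sketch, not part of this pass): driving the patterns
// above directly from C++. `VectorTransferToSCFOptions`,
// `populateVectorToSCFConversionPatterns` and `applyPatternsAndFoldGreedily`
// are the real APIs used earlier in this file; the function name and the
// option values chosen here are made up.
//
//   static void lowerVectorTransfers(Operation *op) {
//     VectorTransferToSCFOptions options;
//     options.unroll = true;  // fully unroll instead of emitting scf.for loops
//     options.targetRank = 1; // keep unpacking until transfers are 1-D
//     RewritePatternSet patterns(op->getContext());
//     populateVectorToSCFConversionPatterns(patterns, options);
//     (void)applyPatternsAndFoldGreedily(op, std::move(patterns));
//   }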