10f241638SMatthias Springer //===- VectorToSCF.cpp - Convert vector to SCF dialect ----------*- C++ -*-===// 24ead2cf7SAlex Zinenko // 34ead2cf7SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 44ead2cf7SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 54ead2cf7SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 64ead2cf7SAlex Zinenko // 74ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===// 84ead2cf7SAlex Zinenko // 90f241638SMatthias Springer // This file implements lowering of vector transfer operations to SCF. 104ead2cf7SAlex Zinenko // 114ead2cf7SAlex Zinenko //===----------------------------------------------------------------------===// 124ead2cf7SAlex Zinenko 134ead2cf7SAlex Zinenko #include <type_traits> 144ead2cf7SAlex Zinenko 154ead2cf7SAlex Zinenko #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 165f9e0466SNicolas Vasilache 175f9e0466SNicolas Vasilache #include "../PassDetail.h" 184ead2cf7SAlex Zinenko #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" 19e2310704SJulian Gross #include "mlir/Dialect/MemRef/EDSC/Intrinsics.h" 204ead2cf7SAlex Zinenko #include "mlir/Dialect/SCF/EDSC/Intrinsics.h" 214ead2cf7SAlex Zinenko #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" 224ead2cf7SAlex Zinenko #include "mlir/Dialect/Vector/EDSC/Intrinsics.h" 234ead2cf7SAlex Zinenko #include "mlir/Dialect/Vector/VectorOps.h" 247c3c5b11SNicolas Vasilache #include "mlir/Dialect/Vector/VectorUtils.h" 254ead2cf7SAlex Zinenko #include "mlir/IR/Builders.h" 265f9e0466SNicolas Vasilache #include "mlir/Pass/Pass.h" 27b6eb26fdSRiver Riddle #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 285f9e0466SNicolas Vasilache #include "mlir/Transforms/Passes.h" 294ead2cf7SAlex Zinenko 304ead2cf7SAlex Zinenko using namespace mlir; 314ead2cf7SAlex Zinenko using namespace mlir::edsc; 324ead2cf7SAlex Zinenko using namespace mlir::edsc::intrinsics; 
334ead2cf7SAlex Zinenko using vector::TransferReadOp; 344ead2cf7SAlex Zinenko using vector::TransferWriteOp; 354ead2cf7SAlex Zinenko 36350dadaaSBenjamin Kramer namespace { 370f241638SMatthias Springer 380f241638SMatthias Springer /// Attribute name used for labeling transfer ops during progressive lowering. 390f241638SMatthias Springer static const char kPassLabel[] = "__vector_to_scf_lowering__"; 400f241638SMatthias Springer 41*2ca887deSMatthias Springer /// Patterns that inherit from this struct have access to 42*2ca887deSMatthias Springer /// VectorTransferToSCFOptions. 43*2ca887deSMatthias Springer template <typename OpTy> 44*2ca887deSMatthias Springer struct VectorToSCFPattern : public OpRewritePattern<OpTy> { 45*2ca887deSMatthias Springer explicit VectorToSCFPattern(MLIRContext *context, 46*2ca887deSMatthias Springer VectorTransferToSCFOptions opt) 47*2ca887deSMatthias Springer : OpRewritePattern<OpTy>(context), options(opt) {} 48*2ca887deSMatthias Springer 49*2ca887deSMatthias Springer VectorTransferToSCFOptions options; 50*2ca887deSMatthias Springer }; 510f241638SMatthias Springer 520f241638SMatthias Springer /// Given a MemRefType with VectorType element type, unpack one dimension from 530f241638SMatthias Springer /// the VectorType into the MemRefType. 
544ead2cf7SAlex Zinenko /// 550f241638SMatthias Springer /// E.g.: memref<9xvector<5x6xf32>> --> memref<9x5xvector<6xf32>> 560f241638SMatthias Springer static MemRefType unpackOneDim(MemRefType type) { 570f241638SMatthias Springer auto vectorType = type.getElementType().dyn_cast<VectorType>(); 580f241638SMatthias Springer auto memrefShape = type.getShape(); 590f241638SMatthias Springer SmallVector<int64_t, 8> newMemrefShape; 600f241638SMatthias Springer newMemrefShape.append(memrefShape.begin(), memrefShape.end()); 610f241638SMatthias Springer newMemrefShape.push_back(vectorType.getDimSize(0)); 620f241638SMatthias Springer return MemRefType::get(newMemrefShape, 630f241638SMatthias Springer VectorType::get(vectorType.getShape().drop_front(), 640f241638SMatthias Springer vectorType.getElementType())); 654ead2cf7SAlex Zinenko } 664ead2cf7SAlex Zinenko 670f241638SMatthias Springer /// Helper data structure for data and mask buffers. 680f241638SMatthias Springer struct BufferAllocs { 690f241638SMatthias Springer Value dataBuffer; 700f241638SMatthias Springer Value maskBuffer; 714ead2cf7SAlex Zinenko }; 724ead2cf7SAlex Zinenko 730f241638SMatthias Springer /// Allocate temporary buffers for data (vector) and mask (if present). 740f241638SMatthias Springer /// TODO: Parallelism and threadlocal considerations. 
750f241638SMatthias Springer template <typename OpTy> 760f241638SMatthias Springer static BufferAllocs allocBuffers(OpTy xferOp) { 77247e185dSNicolas Vasilache auto &b = ScopedContext::getBuilderRef(); 78247e185dSNicolas Vasilache OpBuilder::InsertionGuard guard(b); 79a4b8c2deSJakub Lichman Operation *scope = 800f241638SMatthias Springer xferOp->template getParentWithTrait<OpTrait::AutomaticAllocationScope>(); 81a4b8c2deSJakub Lichman assert(scope && "Expected op to be inside automatic allocation scope"); 82a4b8c2deSJakub Lichman b.setInsertionPointToStart(&scope->getRegion(0).front()); 830f241638SMatthias Springer 840f241638SMatthias Springer BufferAllocs result; 850f241638SMatthias Springer auto bufferType = MemRefType::get({}, xferOp.getVectorType()); 860f241638SMatthias Springer result.dataBuffer = memref_alloca(bufferType).value; 870f241638SMatthias Springer 880f241638SMatthias Springer if (xferOp.mask()) { 890f241638SMatthias Springer auto maskType = MemRefType::get({}, xferOp.mask().getType()); 900f241638SMatthias Springer Value maskBuffer = memref_alloca(maskType); 910f241638SMatthias Springer memref_store(xferOp.mask(), maskBuffer); 920f241638SMatthias Springer result.maskBuffer = memref_load(maskBuffer); 93247e185dSNicolas Vasilache } 94247e185dSNicolas Vasilache 950f241638SMatthias Springer return result; 961870e787SNicolas Vasilache } 977c3c5b11SNicolas Vasilache 980f241638SMatthias Springer /// Given a vector transfer op, calculate which dimension of the `source` 990f241638SMatthias Springer /// memref should be unpacked in the next application of TransferOpConversion. 1000f241638SMatthias Springer /// A return value of None indicates a broadcast. 
1010f241638SMatthias Springer template <typename OpTy> 1020f241638SMatthias Springer static Optional<int64_t> unpackedDim(OpTy xferOp) { 1030f241638SMatthias Springer auto map = xferOp.permutation_map(); 1040f241638SMatthias Springer if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) { 1050f241638SMatthias Springer return expr.getPosition(); 1067c3c5b11SNicolas Vasilache } 1070f241638SMatthias Springer assert(xferOp.isBroadcastDim(0) && 1080f241638SMatthias Springer "Expected AffineDimExpr or AffineConstantExpr"); 1090f241638SMatthias Springer return None; 1100f241638SMatthias Springer } 1110f241638SMatthias Springer 1120f241638SMatthias Springer /// Compute the permutation map for the new (N-1)-D vector transfer op. This 1130f241638SMatthias Springer /// map is identical to the current permutation map, but the first result is 1140f241638SMatthias Springer /// omitted. 1150f241638SMatthias Springer template <typename OpTy> 1160f241638SMatthias Springer static AffineMap unpackedPermutationMap(OpTy xferOp, OpBuilder &builder) { 1170f241638SMatthias Springer auto map = xferOp.permutation_map(); 1180f241638SMatthias Springer return AffineMap::get(map.getNumDims(), 0, map.getResults().drop_front(), 1190f241638SMatthias Springer builder.getContext()); 1200f241638SMatthias Springer } 1210f241638SMatthias Springer 1220f241638SMatthias Springer /// Calculate the indices for the new vector transfer op. 1230f241638SMatthias Springer /// 1240f241638SMatthias Springer /// E.g.: transfer_read %A[%a, %b, %c, %d] ... : vector<5x4x3xf32> ... 1250f241638SMatthias Springer /// --> transfer_read %A[%a, %b + iv, %c, %d] ... vector<4x3f32> 1260f241638SMatthias Springer /// ^^^^^^ 1270f241638SMatthias Springer /// `iv` is the iteration variable of the (new) surrounding loop. 
1280f241638SMatthias Springer template <typename OpTy> 1290f241638SMatthias Springer static void getXferIndices(OpTy xferOp, Value iv, 1300f241638SMatthias Springer SmallVector<Value, 8> &indices) { 1310f241638SMatthias Springer typename OpTy::Adaptor adaptor(xferOp); 1320f241638SMatthias Springer // Corresponding memref dim of the vector dim that is unpacked. 1330f241638SMatthias Springer auto dim = unpackedDim(xferOp); 1340f241638SMatthias Springer auto prevIndices = adaptor.indices(); 1350f241638SMatthias Springer indices.append(prevIndices.begin(), prevIndices.end()); 1360f241638SMatthias Springer 1370f241638SMatthias Springer bool isBroadcast = !dim.hasValue(); 1380f241638SMatthias Springer if (!isBroadcast) { 1390f241638SMatthias Springer using edsc::op::operator+; 1400f241638SMatthias Springer indices[dim.getValue()] = adaptor.indices()[dim.getValue()] + iv; 1410f241638SMatthias Springer } 1420f241638SMatthias Springer } 1430f241638SMatthias Springer 1440f241638SMatthias Springer static void maybeYieldValue(bool hasRetVal, OpBuilder builder, Location loc, 1450f241638SMatthias Springer Value value) { 1460f241638SMatthias Springer if (hasRetVal) { 1470f241638SMatthias Springer builder.create<scf::YieldOp>(loc, value); 1480f241638SMatthias Springer } else { 1490f241638SMatthias Springer builder.create<scf::YieldOp>(loc); 1500f241638SMatthias Springer } 1510f241638SMatthias Springer } 1520f241638SMatthias Springer 1530f241638SMatthias Springer /// Generates a boolean Value that is true if the iv-th bit in xferOp's mask 1540f241638SMatthias Springer /// is set to true. No such check is generated under following circumstances: 1550f241638SMatthias Springer /// * xferOp does not have a mask. 1560f241638SMatthias Springer /// * xferOp's mask is not 1D. (In case of (N>1)-D, a subvector of the mask is 1570f241638SMatthias Springer /// computed and attached to the new transfer op in the pattern.) 
1580f241638SMatthias Springer /// * The to-be-unpacked dim of xferOp is a broadcast. 1590f241638SMatthias Springer template <typename OpTy> 1600f241638SMatthias Springer static Value generateMaskCheck(OpBuilder &builder, OpTy xferOp, Value iv) { 1610f241638SMatthias Springer if (!xferOp.mask()) 1620f241638SMatthias Springer return Value(); 1630f241638SMatthias Springer if (xferOp.getMaskType().getRank() != 1) 1640f241638SMatthias Springer return Value(); 1650f241638SMatthias Springer if (xferOp.isBroadcastDim(0)) 1660f241638SMatthias Springer return Value(); 1670f241638SMatthias Springer 1680f241638SMatthias Springer auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv); 1690f241638SMatthias Springer return vector_extract_element(xferOp.mask(), ivI32).value; 1700f241638SMatthias Springer } 1710f241638SMatthias Springer 1720f241638SMatthias Springer /// Helper function TransferOpConversion and TransferOp1dConversion. 1730f241638SMatthias Springer /// Generate an in-bounds check if the transfer op may go out-of-bounds on the 1740f241638SMatthias Springer /// specified dimension `dim` with the loop iteration variable `iv`. 
1750f241638SMatthias Springer /// E.g., when unpacking dimension 0 from: 1760f241638SMatthias Springer /// ``` 1770f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b] %cst 1780f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?xf32> 1790f241638SMatthias Springer /// ``` 1800f241638SMatthias Springer /// An if check similar to this will be generated inside the loop: 1810f241638SMatthias Springer /// ``` 1820f241638SMatthias Springer /// %d = memref.dim %A, %c0 : memref<?x?xf32> 1830f241638SMatthias Springer /// if (%a + iv < %d) { 1840f241638SMatthias Springer /// (in-bounds case) 1850f241638SMatthias Springer /// } else { 1860f241638SMatthias Springer /// (out-of-bounds case) 1870f241638SMatthias Springer /// } 1880f241638SMatthias Springer /// ``` 1890f241638SMatthias Springer /// 1900f241638SMatthias Springer /// If the transfer is 1D and has a mask, this function generates a more complex 1910f241638SMatthias Springer /// check also accounts for potentially masked out elements. 1920f241638SMatthias Springer /// 1930f241638SMatthias Springer /// This function variant returns the value returned by `inBoundsCase` or 1940f241638SMatthias Springer /// `outOfBoundsCase`. The MLIR type of the return value must be specified in 1950f241638SMatthias Springer /// `resultTypes`. 1960f241638SMatthias Springer template <typename OpTy> 1970f241638SMatthias Springer static Value generateInBoundsCheck( 1980f241638SMatthias Springer OpTy xferOp, Value iv, OpBuilder &builder, Optional<int64_t> dim, 1990f241638SMatthias Springer TypeRange resultTypes, 2000f241638SMatthias Springer function_ref<Value(OpBuilder &, Location)> inBoundsCase, 2010f241638SMatthias Springer function_ref<Value(OpBuilder &, Location)> outOfBoundsCase = nullptr) { 2020f241638SMatthias Springer bool hasRetVal = !resultTypes.empty(); 2030f241638SMatthias Springer Value cond; // Condition to be built... 
2040f241638SMatthias Springer 2050f241638SMatthias Springer // Condition check 1: Access in-bounds? 2060f241638SMatthias Springer bool isBroadcast = !dim.hasValue(); // No in-bounds check for broadcasts. 2070f241638SMatthias Springer if (!xferOp.isDimInBounds(0) && !isBroadcast) { 2080f241638SMatthias Springer auto memrefDim = 2090f241638SMatthias Springer memref_dim(xferOp.source(), std_constant_index(dim.getValue())); 2100f241638SMatthias Springer using edsc::op::operator+; 2110f241638SMatthias Springer auto memrefIdx = xferOp.indices()[dim.getValue()] + iv; 2120f241638SMatthias Springer cond = std_cmpi_sgt(memrefDim.value, memrefIdx); 2130f241638SMatthias Springer } 2140f241638SMatthias Springer 2150f241638SMatthias Springer // Condition check 2: Masked in? 2160f241638SMatthias Springer if (auto maskCond = generateMaskCheck(builder, xferOp, iv)) { 2170f241638SMatthias Springer if (cond) { 2180f241638SMatthias Springer cond = builder.create<AndOp>(xferOp.getLoc(), cond, maskCond); 2190f241638SMatthias Springer } else { 2200f241638SMatthias Springer cond = maskCond; 2210f241638SMatthias Springer } 2220f241638SMatthias Springer } 2230f241638SMatthias Springer 2240f241638SMatthias Springer // If the condition is non-empty, generate an SCF::IfOp. 
2250f241638SMatthias Springer if (cond) { 2260f241638SMatthias Springer auto check = builder.create<scf::IfOp>( 2270f241638SMatthias Springer xferOp.getLoc(), resultTypes, cond, 2280f241638SMatthias Springer /*thenBuilder=*/ 2290f241638SMatthias Springer [&](OpBuilder &builder, Location loc) { 2300f241638SMatthias Springer maybeYieldValue(hasRetVal, builder, loc, inBoundsCase(builder, loc)); 231cadb7ccfSAlex Zinenko }, 2320f241638SMatthias Springer /*elseBuilder=*/ 2330f241638SMatthias Springer [&](OpBuilder &builder, Location loc) { 2340f241638SMatthias Springer if (outOfBoundsCase) { 2350f241638SMatthias Springer maybeYieldValue(hasRetVal, builder, loc, 2360f241638SMatthias Springer outOfBoundsCase(builder, loc)); 2377c3c5b11SNicolas Vasilache } else { 2380f241638SMatthias Springer builder.create<scf::YieldOp>(loc); 2397c3c5b11SNicolas Vasilache } 2407c3c5b11SNicolas Vasilache }); 2417c3c5b11SNicolas Vasilache 2420f241638SMatthias Springer return hasRetVal ? check.getResult(0) : Value(); 2434ead2cf7SAlex Zinenko } 2444ead2cf7SAlex Zinenko 2450f241638SMatthias Springer // Condition is empty, no need for an SCF::IfOp. 2460f241638SMatthias Springer return inBoundsCase(builder, xferOp.getLoc()); 2470f241638SMatthias Springer } 2480f241638SMatthias Springer 2490f241638SMatthias Springer /// In this function variant, `inBoundsCase` and `outOfBoundsCase` do not have 2500f241638SMatthias Springer /// a return value. Consequently, this function does not have a return value. 
2510f241638SMatthias Springer template <typename OpTy> 2520f241638SMatthias Springer static void generateInBoundsCheck( 2530f241638SMatthias Springer OpTy xferOp, Value iv, OpBuilder &builder, Optional<int64_t> dim, 2540f241638SMatthias Springer function_ref<void(OpBuilder &, Location)> inBoundsCase, 2550f241638SMatthias Springer function_ref<void(OpBuilder &, Location)> outOfBoundsCase = nullptr) { 2560f241638SMatthias Springer generateInBoundsCheck( 2570f241638SMatthias Springer xferOp, iv, builder, dim, /*resultTypes=*/TypeRange(), 2580f241638SMatthias Springer /*inBoundsCase=*/ 2590f241638SMatthias Springer [&](OpBuilder &builder, Location loc) { 2600f241638SMatthias Springer inBoundsCase(builder, loc); 2610f241638SMatthias Springer return Value(); 2620f241638SMatthias Springer }, 2630f241638SMatthias Springer /*outOfBoundsCase=*/ 2640f241638SMatthias Springer [&](OpBuilder &builder, Location loc) { 2650f241638SMatthias Springer if (outOfBoundsCase) 2660f241638SMatthias Springer outOfBoundsCase(builder, loc); 2670f241638SMatthias Springer return Value(); 2680f241638SMatthias Springer }); 2690f241638SMatthias Springer } 2700f241638SMatthias Springer 2710f241638SMatthias Springer /// Given an ArrayAttr, return a copy where the first element is dropped. 2720f241638SMatthias Springer static ArrayAttr dropFirstElem(OpBuilder &builder, ArrayAttr attr) { 2730f241638SMatthias Springer if (!attr) 2740f241638SMatthias Springer return attr; 2750f241638SMatthias Springer return ArrayAttr::get(builder.getContext(), attr.getValue().drop_front()); 2760f241638SMatthias Springer } 2770f241638SMatthias Springer 2780f241638SMatthias Springer /// Add the pass label to a vector transfer op if its rank is not the target 2790f241638SMatthias Springer /// rank. 
2800f241638SMatthias Springer template <typename OpTy> 281*2ca887deSMatthias Springer static void maybeApplyPassLabel(OpBuilder &builder, OpTy newXferOp, 282*2ca887deSMatthias Springer unsigned targetRank) { 283*2ca887deSMatthias Springer if (newXferOp.getVectorType().getRank() > targetRank) 2840f241638SMatthias Springer newXferOp->setAttr(kPassLabel, builder.getUnitAttr()); 2850f241638SMatthias Springer } 2860f241638SMatthias Springer 2870f241638SMatthias Springer /// Given a transfer op, find the memref from which the mask is loaded. This 2880f241638SMatthias Springer /// is similar to Strategy<TransferWriteOp>::getBuffer. 2890f241638SMatthias Springer template <typename OpTy> 2900f241638SMatthias Springer static Value getMaskBuffer(OpTy xferOp) { 2910f241638SMatthias Springer assert(xferOp.mask() && "Expected that transfer op has mask"); 2920f241638SMatthias Springer auto loadOp = xferOp.mask().template getDefiningOp<memref::LoadOp>(); 2930f241638SMatthias Springer assert(loadOp && "Expected transfer op mask produced by LoadOp"); 2940f241638SMatthias Springer return loadOp.getMemRef(); 2950f241638SMatthias Springer } 2960f241638SMatthias Springer 2970f241638SMatthias Springer /// Codegen strategy, depending on the operation. 2980f241638SMatthias Springer template <typename OpTy> 2990f241638SMatthias Springer struct Strategy; 3000f241638SMatthias Springer 3010f241638SMatthias Springer /// Code strategy for vector TransferReadOp. 3024ead2cf7SAlex Zinenko template <> 3030f241638SMatthias Springer struct Strategy<TransferReadOp> { 3040f241638SMatthias Springer /// Find the StoreOp that is used for writing the current TransferReadOp's 3050f241638SMatthias Springer /// result to the temporary buffer allocation. 
3060f241638SMatthias Springer static memref::StoreOp getStoreOp(TransferReadOp xferOp) { 3070f241638SMatthias Springer assert(xferOp->hasOneUse() && "Expected exactly one use of TransferReadOp"); 3080f241638SMatthias Springer auto storeOp = dyn_cast<memref::StoreOp>((*xferOp->use_begin()).getOwner()); 3090f241638SMatthias Springer assert(storeOp && "Expected TransferReadOp result used by StoreOp"); 3100f241638SMatthias Springer return storeOp; 3117c3c5b11SNicolas Vasilache } 3124ead2cf7SAlex Zinenko 3130f241638SMatthias Springer /// Find the temporary buffer allocation. All labeled TransferReadOps are 3140f241638SMatthias Springer /// used like this, where %buf is either the buffer allocation or a type cast 3150f241638SMatthias Springer /// of the buffer allocation: 3160f241638SMatthias Springer /// ``` 3170f241638SMatthias Springer /// %vec = vector.transfer_read ... { __vector_to_scf_lowering__ } ... 3180f241638SMatthias Springer /// memref.store %vec, %buf[...] ... 3190f241638SMatthias Springer /// ``` 3200f241638SMatthias Springer static Value getBuffer(TransferReadOp xferOp) { 3210f241638SMatthias Springer return getStoreOp(xferOp).getMemRef(); 3221870e787SNicolas Vasilache } 3230f241638SMatthias Springer 3240f241638SMatthias Springer /// Retrieve the indices of the current StoreOp that stores into the buffer. 3250f241638SMatthias Springer static void getBufferIndices(TransferReadOp xferOp, 3260f241638SMatthias Springer SmallVector<Value, 8> &indices) { 3270f241638SMatthias Springer auto storeOp = getStoreOp(xferOp); 3280f241638SMatthias Springer auto prevIndices = memref::StoreOpAdaptor(storeOp).indices(); 3290f241638SMatthias Springer indices.append(prevIndices.begin(), prevIndices.end()); 3300f241638SMatthias Springer } 3310f241638SMatthias Springer 3320f241638SMatthias Springer /// Rewrite the TransferReadOp, assuming that there are no out-of-bounds 3330f241638SMatthias Springer /// accesses on the to-be-unpacked dimension. 
3340f241638SMatthias Springer /// 3350f241638SMatthias Springer /// 1. Generate a new (N-1)-d TransferReadOp using the loop iteration 3360f241638SMatthias Springer /// variable `iv`. 3370f241638SMatthias Springer /// 2. Store the result into the (already `vector.type_cast`ed) buffer. 3380f241638SMatthias Springer /// 3390f241638SMatthias Springer /// E.g.: 3400f241638SMatthias Springer /// ``` 3410f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a+%i, %b, %c], %cst 3420f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<4x3xf32> 3430f241638SMatthias Springer /// memref.store %vec, %buf[%i] : memref<5xvector<4x3xf32>> 3440f241638SMatthias Springer /// ``` 3450f241638SMatthias Springer /// Is rewritten to: 3460f241638SMatthias Springer /// ``` 3470f241638SMatthias Springer /// %casted = vector.type_cast %buf 3480f241638SMatthias Springer /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>> 3490f241638SMatthias Springer /// for %j = 0 to 4 { 3500f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a+%i, %b+%j, %c], %cst 3510f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<3xf32> 3520f241638SMatthias Springer /// memref.store %vec, %casted[%i, %j] : memref<5x4xvector<3xf32>> 3530f241638SMatthias Springer /// } 3540f241638SMatthias Springer /// ``` 3550f241638SMatthias Springer /// 3560f241638SMatthias Springer /// Note: The loop and type cast are generated in TransferOpConversion. 3570f241638SMatthias Springer /// The original TransferReadOp and store op are deleted in `cleanup`. 3580f241638SMatthias Springer /// Note: The `mask` operand is set in TransferOpConversion. 
359*2ca887deSMatthias Springer static TransferReadOp rewriteOp(OpBuilder &builder, 360*2ca887deSMatthias Springer VectorTransferToSCFOptions options, 361*2ca887deSMatthias Springer TransferReadOp xferOp, Value buffer, 362*2ca887deSMatthias Springer Value iv) { 3630f241638SMatthias Springer SmallVector<Value, 8> storeIndices; 3640f241638SMatthias Springer getBufferIndices(xferOp, storeIndices); 3650f241638SMatthias Springer storeIndices.push_back(iv); 3660f241638SMatthias Springer 3670f241638SMatthias Springer SmallVector<Value, 8> xferIndices; 3680f241638SMatthias Springer getXferIndices(xferOp, iv, xferIndices); 3690f241638SMatthias Springer 3700f241638SMatthias Springer auto bufferType = buffer.getType().dyn_cast<ShapedType>(); 3710f241638SMatthias Springer auto vecType = bufferType.getElementType().dyn_cast<VectorType>(); 3720f241638SMatthias Springer auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr()); 3730f241638SMatthias Springer auto newXfer = 3740f241638SMatthias Springer vector_transfer_read( 3750f241638SMatthias Springer vecType, xferOp.source(), xferIndices, 3760f241638SMatthias Springer AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), 3770f241638SMatthias Springer xferOp.padding(), Value(), inBoundsAttr) 3780f241638SMatthias Springer .value; 3790f241638SMatthias Springer 3800f241638SMatthias Springer maybeApplyPassLabel(builder, 381*2ca887deSMatthias Springer dyn_cast<TransferReadOp>(newXfer.getDefiningOp()), 382*2ca887deSMatthias Springer options.targetRank); 3830f241638SMatthias Springer 3840f241638SMatthias Springer memref_store(newXfer, buffer, storeIndices); 3850f241638SMatthias Springer return newXfer.getDefiningOp<TransferReadOp>(); 3860f241638SMatthias Springer } 3870f241638SMatthias Springer 3880f241638SMatthias Springer /// Handle out-of-bounds accesses on the to-be-unpacked dimension: Write 3890f241638SMatthias Springer /// padding value to the temporary buffer. 
3900f241638SMatthias Springer static void handleOutOfBoundsDim(OpBuilder & /*builder*/, 3910f241638SMatthias Springer TransferReadOp xferOp, Value buffer, 3920f241638SMatthias Springer Value iv) { 3930f241638SMatthias Springer SmallVector<Value, 8> storeIndices; 3940f241638SMatthias Springer getBufferIndices(xferOp, storeIndices); 3950f241638SMatthias Springer storeIndices.push_back(iv); 3960f241638SMatthias Springer 3970f241638SMatthias Springer auto bufferType = buffer.getType().dyn_cast<ShapedType>(); 3980f241638SMatthias Springer auto vecType = bufferType.getElementType().dyn_cast<VectorType>(); 3990f241638SMatthias Springer auto vec = std_splat(vecType, xferOp.padding()); 4000f241638SMatthias Springer memref_store(vec, buffer, storeIndices); 4010f241638SMatthias Springer } 4020f241638SMatthias Springer 4030f241638SMatthias Springer /// Cleanup after rewriting the op. 4040f241638SMatthias Springer static void cleanup(PatternRewriter &rewriter, TransferReadOp xferOp) { 4050f241638SMatthias Springer rewriter.eraseOp(getStoreOp(xferOp)); 4060f241638SMatthias Springer rewriter.eraseOp(xferOp); 4070f241638SMatthias Springer } 4084ead2cf7SAlex Zinenko }; 4097c3c5b11SNicolas Vasilache 4100f241638SMatthias Springer /// Codegen strategy for vector TransferWriteOp. 4110f241638SMatthias Springer template <> 4120f241638SMatthias Springer struct Strategy<TransferWriteOp> { 4130f241638SMatthias Springer /// Find the temporary buffer allocation. All labeled TransferWriteOps are 4140f241638SMatthias Springer /// used like this, where %buf is either the buffer allocation or a type cast 4150f241638SMatthias Springer /// of the buffer allocation: 4160f241638SMatthias Springer /// ``` 4170f241638SMatthias Springer /// %vec = memref.load %buf[...] ... 4180f241638SMatthias Springer /// vector.transfer_write %vec ... { __vector_to_scf_lowering__ } ... 
4190f241638SMatthias Springer /// ``` 4200f241638SMatthias Springer static Value getBuffer(TransferWriteOp xferOp) { 4210f241638SMatthias Springer auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>(); 4220f241638SMatthias Springer assert(loadOp && "Expected transfer op vector produced by LoadOp"); 4230f241638SMatthias Springer return loadOp.getMemRef(); 4247c3c5b11SNicolas Vasilache } 4254ead2cf7SAlex Zinenko 4260f241638SMatthias Springer /// Retrieve the indices of the current LoadOp that loads from the buffer. 4270f241638SMatthias Springer static void getBufferIndices(TransferWriteOp xferOp, 4280f241638SMatthias Springer SmallVector<Value, 8> &indices) { 4290f241638SMatthias Springer auto loadOp = xferOp.vector().getDefiningOp<memref::LoadOp>(); 4300f241638SMatthias Springer auto prevIndices = memref::LoadOpAdaptor(loadOp).indices(); 4310f241638SMatthias Springer indices.append(prevIndices.begin(), prevIndices.end()); 4320f241638SMatthias Springer } 4330f241638SMatthias Springer 4340f241638SMatthias Springer /// Rewrite the TransferWriteOp, assuming that there are no out-of-bounds 4350f241638SMatthias Springer /// accesses on the to-be-unpacked dimension. 4360f241638SMatthias Springer /// 4370f241638SMatthias Springer /// 1. Load an (N-1)-d vector from the (already `vector.type_cast`ed) buffer, 4380f241638SMatthias Springer /// using the loop iteration variable `iv`. 4390f241638SMatthias Springer /// 2. Generate a new (N-1)-d TransferWriteOp, writing the loaded vector back 4400f241638SMatthias Springer /// to memory. 4410f241638SMatthias Springer /// 4420f241638SMatthias Springer /// Note: For more details, see comments on Strategy<TransferReadOp>. 
443*2ca887deSMatthias Springer static TransferWriteOp rewriteOp(OpBuilder &builder, 444*2ca887deSMatthias Springer VectorTransferToSCFOptions options, 445*2ca887deSMatthias Springer TransferWriteOp xferOp, Value buffer, 446*2ca887deSMatthias Springer Value iv) { 4470f241638SMatthias Springer SmallVector<Value, 8> loadIndices; 4480f241638SMatthias Springer getBufferIndices(xferOp, loadIndices); 4490f241638SMatthias Springer loadIndices.push_back(iv); 4500f241638SMatthias Springer 4510f241638SMatthias Springer SmallVector<Value, 8> xferIndices; 4520f241638SMatthias Springer getXferIndices(xferOp, iv, xferIndices); 4530f241638SMatthias Springer 4540f241638SMatthias Springer auto vec = memref_load(buffer, loadIndices); 4550f241638SMatthias Springer auto inBoundsAttr = dropFirstElem(builder, xferOp.in_boundsAttr()); 4560f241638SMatthias Springer auto newXfer = vector_transfer_write( 4570f241638SMatthias Springer Type(), vec, xferOp.source(), xferIndices, 4580f241638SMatthias Springer AffineMapAttr::get(unpackedPermutationMap(xferOp, builder)), Value(), 4590f241638SMatthias Springer inBoundsAttr); 4600f241638SMatthias Springer 461*2ca887deSMatthias Springer maybeApplyPassLabel(builder, newXfer.op, options.targetRank); 4620f241638SMatthias Springer 4630f241638SMatthias Springer return newXfer; 4640f241638SMatthias Springer } 4650f241638SMatthias Springer 4660f241638SMatthias Springer /// Handle out-of-bounds accesses on the to-be-unpacked dimension. 4670f241638SMatthias Springer static void handleOutOfBoundsDim(OpBuilder &builder, TransferWriteOp xferOp, 4680f241638SMatthias Springer Value buffer, Value iv) {} 4690f241638SMatthias Springer 4700f241638SMatthias Springer /// Cleanup after rewriting the op. 
4710f241638SMatthias Springer static void cleanup(PatternRewriter &rewriter, TransferWriteOp xferOp) { 4720f241638SMatthias Springer rewriter.eraseOp(xferOp); 4730f241638SMatthias Springer } 4740f241638SMatthias Springer }; 4750f241638SMatthias Springer 4760f241638SMatthias Springer template <typename OpTy> 477*2ca887deSMatthias Springer LogicalResult checkPrepareXferOp(OpTy xferOp, unsigned targetRank) { 4780f241638SMatthias Springer if (xferOp->hasAttr(kPassLabel)) 4790f241638SMatthias Springer return failure(); 480*2ca887deSMatthias Springer if (xferOp.getVectorType().getRank() <= targetRank) 4810f241638SMatthias Springer return failure(); 4820f241638SMatthias Springer return success(); 4830f241638SMatthias Springer } 4840f241638SMatthias Springer 4850f241638SMatthias Springer /// Prepare a TransferReadOp for progressive lowering. 4860f241638SMatthias Springer /// 4870f241638SMatthias Springer /// 1. Allocate a temporary buffer. 4880f241638SMatthias Springer /// 2. Label the TransferReadOp, marking it eligible for progressive lowering. 4890f241638SMatthias Springer /// 3. Store the result of the TransferReadOp into the temporary buffer. 4900f241638SMatthias Springer /// 4. Load the result from the temporary buffer and replace all uses of the 4910f241638SMatthias Springer /// original TransferReadOp with this load. 
4920f241638SMatthias Springer /// 4930f241638SMatthias Springer /// E.g.: 4940f241638SMatthias Springer /// ``` 4950f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %cst 4960f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?x?xf32> 4970f241638SMatthias Springer /// ``` 4980f241638SMatthias Springer /// is rewritten to: 4990f241638SMatthias Springer /// ``` 5000f241638SMatthias Springer /// %0 = memref.alloca() : memref<vector<5x4xf32>> 5010f241638SMatthias Springer /// %1 = vector.transfer_read %A[%a, %b, %c], %cst 5020f241638SMatthias Springer /// { __vector_to_scf_lowering__ } : vector<5x4xf32>, memref<?x?x?xf32> 5030f241638SMatthias Springer /// memref.store %1, %0[] : memref<vector<5x4xf32>> 5040f241638SMatthias Springer /// %vec = memref.load %0[] : memref<vector<5x4xf32>> 5050f241638SMatthias Springer /// ``` 5060f241638SMatthias Springer /// 5070f241638SMatthias Springer /// Note: A second temporary buffer may be allocated for the `mask` operand. 
508*2ca887deSMatthias Springer struct PrepareTransferReadConversion 509*2ca887deSMatthias Springer : public VectorToSCFPattern<TransferReadOp> { 510*2ca887deSMatthias Springer using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern; 5110f241638SMatthias Springer 5120f241638SMatthias Springer LogicalResult matchAndRewrite(TransferReadOp xferOp, 5130f241638SMatthias Springer PatternRewriter &rewriter) const override { 514*2ca887deSMatthias Springer if (checkPrepareXferOp(xferOp, options.targetRank).failed()) 5150f241638SMatthias Springer return failure(); 5160f241638SMatthias Springer 5170f241638SMatthias Springer ScopedContext scope(rewriter, xferOp.getLoc()); 5180f241638SMatthias Springer auto buffers = allocBuffers(xferOp); 5190f241638SMatthias Springer auto *newXfer = rewriter.clone(*xferOp.getOperation()); 5200f241638SMatthias Springer newXfer->setAttr(kPassLabel, rewriter.getUnitAttr()); 5210f241638SMatthias Springer if (xferOp.mask()) { 5220f241638SMatthias Springer dyn_cast<TransferReadOp>(newXfer).maskMutable().assign( 5230f241638SMatthias Springer buffers.maskBuffer); 5240f241638SMatthias Springer } 5250f241638SMatthias Springer 5260f241638SMatthias Springer memref_store(newXfer->getResult(0), buffers.dataBuffer); 5270f241638SMatthias Springer rewriter.replaceOpWithNewOp<memref::LoadOp>(xferOp, buffers.dataBuffer); 5284ead2cf7SAlex Zinenko 5294ead2cf7SAlex Zinenko return success(); 5304ead2cf7SAlex Zinenko } 5310f241638SMatthias Springer }; 5320f241638SMatthias Springer 5330f241638SMatthias Springer /// Prepare a TransferWriteOp for progressive lowering. 5340f241638SMatthias Springer /// 5350f241638SMatthias Springer /// 1. Allocate a temporary buffer. 5360f241638SMatthias Springer /// 2. Store the vector into the buffer. 5370f241638SMatthias Springer /// 3. Load the vector from the buffer again. 5380f241638SMatthias Springer /// 4. 
Use the loaded vector as a TransferWriteOp operand and label the op, 5390f241638SMatthias Springer /// marking it eligible for progressive lowering via TransferOpConversion. 5400f241638SMatthias Springer /// 5410f241638SMatthias Springer /// E.g.: 5420f241638SMatthias Springer /// ``` 5430f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c] 5440f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?x?xf32> 5450f241638SMatthias Springer /// ``` 5460f241638SMatthias Springer /// is rewritten to: 5470f241638SMatthias Springer /// ``` 5480f241638SMatthias Springer /// %0 = memref.alloca() : memref<vector<5x4xf32>> 5490f241638SMatthias Springer /// memref.store %vec, %0[] : memref<vector<5x4xf32>> 5500f241638SMatthias Springer /// %1 = memref.load %0[] : memref<vector<5x4xf32>> 5510f241638SMatthias Springer /// vector.transfer_write %1, %A[%a, %b, %c] { __vector_to_scf_lowering__ } 5520f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?x?xf32> 5530f241638SMatthias Springer /// ``` 5540f241638SMatthias Springer /// 5550f241638SMatthias Springer /// Note: A second temporary buffer may be allocated for the `mask` operand. 
5560f241638SMatthias Springer struct PrepareTransferWriteConversion 557*2ca887deSMatthias Springer : public VectorToSCFPattern<TransferWriteOp> { 558*2ca887deSMatthias Springer using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern; 5590f241638SMatthias Springer 5600f241638SMatthias Springer LogicalResult matchAndRewrite(TransferWriteOp xferOp, 5610f241638SMatthias Springer PatternRewriter &rewriter) const override { 562*2ca887deSMatthias Springer if (checkPrepareXferOp(xferOp, options.targetRank).failed()) 5630f241638SMatthias Springer return failure(); 5640f241638SMatthias Springer 5650f241638SMatthias Springer ScopedContext scope(rewriter, xferOp.getLoc()); 5660f241638SMatthias Springer auto buffers = allocBuffers(xferOp); 5670f241638SMatthias Springer memref_store(xferOp.vector(), buffers.dataBuffer); 5680f241638SMatthias Springer auto loadedVec = memref_load(buffers.dataBuffer); 5690f241638SMatthias Springer rewriter.updateRootInPlace(xferOp, [&]() { 5700f241638SMatthias Springer xferOp.vectorMutable().assign(loadedVec); 5710f241638SMatthias Springer xferOp->setAttr(kPassLabel, rewriter.getUnitAttr()); 5720f241638SMatthias Springer }); 5730f241638SMatthias Springer 5740f241638SMatthias Springer if (xferOp.mask()) { 5750f241638SMatthias Springer rewriter.updateRootInPlace( 5760f241638SMatthias Springer xferOp, [&]() { xferOp.maskMutable().assign(buffers.maskBuffer); }); 5770f241638SMatthias Springer } 5780f241638SMatthias Springer 5790f241638SMatthias Springer return success(); 5800f241638SMatthias Springer } 5810f241638SMatthias Springer }; 5820f241638SMatthias Springer 5830f241638SMatthias Springer /// Progressive lowering of vector transfer ops: Unpack one dimension. 5840f241638SMatthias Springer /// 5850f241638SMatthias Springer /// 1. Unpack one dimension from the current buffer type and cast the buffer 5860f241638SMatthias Springer /// to that new type. 
E.g.: 5870f241638SMatthias Springer /// ``` 5880f241638SMatthias Springer /// %vec = memref.load %0[%1] : memref<5xvector<4x3xf32>> 5890f241638SMatthias Springer /// vector.transfer_write %vec ... 5900f241638SMatthias Springer /// ``` 5910f241638SMatthias Springer /// The following cast is generated: 5920f241638SMatthias Springer /// ``` 5930f241638SMatthias Springer /// %casted = vector.type_cast %0 5940f241638SMatthias Springer /// : memref<5xvector<4x3xf32>> to memref<5x4xvector<3xf32>> 5950f241638SMatthias Springer /// ``` 5960f241638SMatthias Springer /// 2. Generate a for loop and rewrite the transfer op according to the 5970f241638SMatthias Springer /// corresponding Strategy<OpTy>. If the to-be-unpacked dimension can be 5980f241638SMatthias Springer /// out-of-bounds, generate an if-check and handle both cases separately. 5990f241638SMatthias Springer /// 3. Clean up according to the corresponding Strategy<OpTy>. 6000f241638SMatthias Springer template <typename OpTy> 601*2ca887deSMatthias Springer struct TransferOpConversion : public VectorToSCFPattern<OpTy> { 602*2ca887deSMatthias Springer using VectorToSCFPattern<OpTy>::VectorToSCFPattern; 6030f241638SMatthias Springer 6040f241638SMatthias Springer LogicalResult matchAndRewrite(OpTy xferOp, 6050f241638SMatthias Springer PatternRewriter &rewriter) const override { 6060f241638SMatthias Springer if (!xferOp->hasAttr(kPassLabel)) 6070f241638SMatthias Springer return failure(); 6080f241638SMatthias Springer 6090f241638SMatthias Springer ScopedContext scope(rewriter, xferOp.getLoc()); 6100f241638SMatthias Springer 6110f241638SMatthias Springer // Find and cast data buffer. How the buffer can be found depends on OpTy. 
6120f241638SMatthias Springer auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp); 6130f241638SMatthias Springer auto dataBufferType = dataBuffer.getType().template dyn_cast<MemRefType>(); 6140f241638SMatthias Springer auto castedDataType = unpackOneDim(dataBufferType); 6150f241638SMatthias Springer auto castedDataBuffer = vector_type_cast(castedDataType, dataBuffer); 6160f241638SMatthias Springer 6170f241638SMatthias Springer // If the xferOp has a mask: Find and cast mask buffer. 6180f241638SMatthias Springer Value castedMaskBuffer; 6190f241638SMatthias Springer if (xferOp.mask()) { 6200f241638SMatthias Springer auto maskBuffer = getMaskBuffer(xferOp); 6210f241638SMatthias Springer auto maskBufferType = 6220f241638SMatthias Springer maskBuffer.getType().template dyn_cast<MemRefType>(); 6230f241638SMatthias Springer if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) { 6240f241638SMatthias Springer // Do not unpack a dimension of the mask, if: 6250f241638SMatthias Springer // * To-be-unpacked transfer op dimension is a broadcast. 6260f241638SMatthias Springer // * Mask is 1D, i.e., the mask cannot be further unpacked. 6270f241638SMatthias Springer // (That means that all remaining dimensions of the transfer op must 6280f241638SMatthias Springer // be broadcasted.) 6290f241638SMatthias Springer castedMaskBuffer = maskBuffer; 6300f241638SMatthias Springer } else { 6310f241638SMatthias Springer auto castedMaskType = unpackOneDim(maskBufferType); 6320f241638SMatthias Springer castedMaskBuffer = vector_type_cast(castedMaskType, maskBuffer); 6330f241638SMatthias Springer } 6340f241638SMatthias Springer } 6350f241638SMatthias Springer 6360f241638SMatthias Springer // Loop bounds and step. 
6370f241638SMatthias Springer auto lb = std_constant_index(0).value; 6380f241638SMatthias Springer auto ub = std_constant_index( 6390f241638SMatthias Springer castedDataType.getDimSize(castedDataType.getRank() - 1)) 6400f241638SMatthias Springer .value; 6410f241638SMatthias Springer auto step = std_constant_index(1).value; 6420f241638SMatthias Springer 6430f241638SMatthias Springer // Generate for loop. 6440f241638SMatthias Springer rewriter.create<scf::ForOp>( 6450f241638SMatthias Springer xferOp.getLoc(), lb, ub, step, ValueRange(), 6460f241638SMatthias Springer [&](OpBuilder &b, Location loc, Value iv, ValueRange /*loopState*/) { 6470f241638SMatthias Springer ScopedContext scope(b, loc); 6480f241638SMatthias Springer generateInBoundsCheck( 6490f241638SMatthias Springer xferOp, iv, b, unpackedDim(xferOp), 6500f241638SMatthias Springer /*inBoundsCase=*/ 6510f241638SMatthias Springer [&](OpBuilder &b, Location /*loc*/) { 6520f241638SMatthias Springer // Create new transfer op. 653*2ca887deSMatthias Springer OpTy newXfer = Strategy<OpTy>::rewriteOp( 654*2ca887deSMatthias Springer b, this->options, xferOp, castedDataBuffer, iv); 6550f241638SMatthias Springer 6560f241638SMatthias Springer // If old transfer op has a mask: Set mask on new transfer op. 6570f241638SMatthias Springer // Special case: If the mask of the old transfer op is 1D and 6580f241638SMatthias Springer // the 6590f241638SMatthias Springer // unpacked dim is not a broadcast, no mask is 6600f241638SMatthias Springer // needed on the new transfer op. 6610f241638SMatthias Springer if (xferOp.mask() && (xferOp.isBroadcastDim(0) || 6620f241638SMatthias Springer xferOp.getMaskType().getRank() > 1)) { 6630f241638SMatthias Springer OpBuilder::InsertionGuard guard(b); 6640f241638SMatthias Springer b.setInsertionPoint(newXfer); // Insert load before newXfer. 
6650f241638SMatthias Springer 6660f241638SMatthias Springer SmallVector<Value, 8> loadIndices; 6670f241638SMatthias Springer Strategy<OpTy>::getBufferIndices(xferOp, loadIndices); 6680f241638SMatthias Springer // In case of broadcast: Use same indices to load from memref 6690f241638SMatthias Springer // as before. 6700f241638SMatthias Springer if (!xferOp.isBroadcastDim(0)) 6710f241638SMatthias Springer loadIndices.push_back(iv); 6720f241638SMatthias Springer 6730f241638SMatthias Springer auto mask = memref_load(castedMaskBuffer, loadIndices); 6740f241638SMatthias Springer rewriter.updateRootInPlace( 6750f241638SMatthias Springer newXfer, [&]() { newXfer.maskMutable().assign(mask); }); 6760f241638SMatthias Springer } 6770f241638SMatthias Springer }, 6780f241638SMatthias Springer /*outOfBoundsCase=*/ 6790f241638SMatthias Springer [&](OpBuilder &b, Location /*loc*/) { 6800f241638SMatthias Springer Strategy<OpTy>::handleOutOfBoundsDim(b, xferOp, 6810f241638SMatthias Springer castedDataBuffer, iv); 6820f241638SMatthias Springer }); 6830f241638SMatthias Springer b.create<scf::YieldOp>(loc); 6840f241638SMatthias Springer }); 6850f241638SMatthias Springer 6860f241638SMatthias Springer Strategy<OpTy>::cleanup(rewriter, xferOp); 6870f241638SMatthias Springer return success(); 6880f241638SMatthias Springer } 6890f241638SMatthias Springer }; 6900f241638SMatthias Springer 6910f241638SMatthias Springer /// If the original transfer op has a mask, compute the mask of the new transfer 6920f241638SMatthias Springer /// op (for the current iteration `i`) and assign it. 
6930f241638SMatthias Springer template <typename OpTy> 6940f241638SMatthias Springer static void maybeAssignMask(OpBuilder &builder, OpTy xferOp, OpTy newXferOp, 6950f241638SMatthias Springer int64_t i) { 6960f241638SMatthias Springer if (!xferOp.mask()) 6970f241638SMatthias Springer return; 6980f241638SMatthias Springer 6990f241638SMatthias Springer if (xferOp.isBroadcastDim(0)) { 7000f241638SMatthias Springer // To-be-unpacked dimension is a broadcast, which does not have a 7010f241638SMatthias Springer // corresponding mask dimension. Mask attribute remains unchanged. 7020f241638SMatthias Springer newXferOp.maskMutable().assign(xferOp.mask()); 7030f241638SMatthias Springer return; 7040f241638SMatthias Springer } 7050f241638SMatthias Springer 7060f241638SMatthias Springer if (xferOp.getMaskType().getRank() > 1) { 7070f241638SMatthias Springer // Unpack one dimension of the mask. 7080f241638SMatthias Springer OpBuilder::InsertionGuard guard(builder); 7090f241638SMatthias Springer builder.setInsertionPoint(newXferOp); // Insert load before newXfer. 7100f241638SMatthias Springer 7110f241638SMatthias Springer llvm::SmallVector<int64_t, 1> indices({i}); 7120f241638SMatthias Springer auto newMask = vector_extract(xferOp.mask(), indices).value; 7130f241638SMatthias Springer newXferOp.maskMutable().assign(newMask); 7140f241638SMatthias Springer } 7150f241638SMatthias Springer 7160f241638SMatthias Springer // If we end up here: The mask of the old transfer op is 1D and the unpacked 7170f241638SMatthias Springer // dim is not a broadcast, so no mask is needed on the new transfer op. 7180f241638SMatthias Springer // `generateInBoundsCheck` will have evaluated the mask already. 7190f241638SMatthias Springer } 7200f241638SMatthias Springer 7210f241638SMatthias Springer /// Progressive lowering of vector TransferReadOp with unrolling: Unpack one 7220f241638SMatthias Springer /// dimension. 
This is similar to TransferOpConversion<TransferReadOp>, but no 7230f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled. 7240f241638SMatthias Springer /// 7250f241638SMatthias Springer /// ``` 7260f241638SMatthias Springer /// E.g.: 7270f241638SMatthias Springer /// ``` 7280f241638SMatthias Springer /// %vec = vector.transfer_read %A[%a, %b, %c], %padding 7290f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<5x4xf32> 7300f241638SMatthias Springer /// ``` 7310f241638SMatthias Springer /// is rewritten to IR such as (simplified): 7320f241638SMatthias Springer /// ``` 7330f241638SMatthias Springer /// %v_init = splat %padding : vector<5x4xf32> 7340f241638SMatthias Springer /// %tmp0 = vector.transfer_read %A[%a, %b, %c], %padding 7350f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<4xf32> 7360f241638SMatthias Springer /// %v0 = vector.insert %tmp0, %v_init[0] : vector<4xf32> into vector<5x4xf32> 7370f241638SMatthias Springer /// %tmp1 = vector.transfer_read %A[%a, %b + 1, %c], %padding 7380f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<4xf32> 7390f241638SMatthias Springer /// %v1 = vector.insert %tmp1, %v0[1] : vector<4xf32> into vector<5x4xf32> 7400f241638SMatthias Springer /// ... 7410f241638SMatthias Springer /// %tmp4 = vector.transfer_read %A[%a, %b + 4, %c], %padding 7420f241638SMatthias Springer /// : memref<?x?x?xf32>, vector<4xf32> 7430f241638SMatthias Springer /// %vec = vector.insert %tmp4, %v3[4] : vector<4xf32> into vector<5x4xf32> 7440f241638SMatthias Springer /// ``` 7450f241638SMatthias Springer /// 7460f241638SMatthias Springer /// Note: As an optimization, if the result of the original TransferReadOp 7470f241638SMatthias Springer /// was directly inserted into another vector, no new %v_init vector is created. 7480f241638SMatthias Springer /// Instead, the new TransferReadOp results are inserted into that vector.
749*2ca887deSMatthias Springer struct UnrollTransferReadConversion 750*2ca887deSMatthias Springer : public VectorToSCFPattern<TransferReadOp> { 751*2ca887deSMatthias Springer using VectorToSCFPattern<TransferReadOp>::VectorToSCFPattern; 7520f241638SMatthias Springer 7530f241638SMatthias Springer /// Return the vector into which the newly created TransferReadOp results 7540f241638SMatthias Springer /// are inserted. 7550f241638SMatthias Springer Value getResultVector(TransferReadOp xferOp, 7560f241638SMatthias Springer PatternRewriter &rewriter) const { 7570f241638SMatthias Springer if (auto insertOp = getInsertOp(xferOp)) 7580f241638SMatthias Springer return insertOp.dest(); 7590f241638SMatthias Springer return std_splat(xferOp.getVectorType(), xferOp.padding()).value; 7600f241638SMatthias Springer } 7610f241638SMatthias Springer 7620f241638SMatthias Springer /// If the result of the TransferReadOp has exactly one user, which is a 7630f241638SMatthias Springer /// vector::InsertOp, return that operation. 7640f241638SMatthias Springer vector::InsertOp getInsertOp(TransferReadOp xferOp) const { 7650f241638SMatthias Springer if (xferOp->hasOneUse()) { 7660f241638SMatthias Springer Operation *xferOpUser = *xferOp->getUsers().begin(); 7670f241638SMatthias Springer if (auto insertOp = dyn_cast<vector::InsertOp>(xferOpUser)) 7680f241638SMatthias Springer return insertOp; 7690f241638SMatthias Springer } 7700f241638SMatthias Springer 7710f241638SMatthias Springer return vector::InsertOp(); 7720f241638SMatthias Springer } 7730f241638SMatthias Springer 7740f241638SMatthias Springer /// If the result of the TransferReadOp has exactly one user, which is a 7750f241638SMatthias Springer /// vector::InsertOp, return that operation's indices. 
7760f241638SMatthias Springer void getInsertionIndices(TransferReadOp xferOp, 7770f241638SMatthias Springer SmallVector<int64_t, 8> &indices) const { 7780f241638SMatthias Springer if (auto insertOp = getInsertOp(xferOp)) { 7790f241638SMatthias Springer llvm::for_each(insertOp.position(), [&](Attribute attr) { 7800f241638SMatthias Springer indices.push_back(attr.dyn_cast<IntegerAttr>().getInt()); 7810f241638SMatthias Springer }); 7820f241638SMatthias Springer } 7830f241638SMatthias Springer } 7840f241638SMatthias Springer 7850f241638SMatthias Springer /// Rewrite the op: Unpack one dimension. Can handle masks, out-of-bounds 7860f241638SMatthias Springer /// accesses, and broadcasts and transposes in permutation maps. 7870f241638SMatthias Springer LogicalResult matchAndRewrite(TransferReadOp xferOp, 7880f241638SMatthias Springer PatternRewriter &rewriter) const override { 789*2ca887deSMatthias Springer if (xferOp.getVectorType().getRank() <= options.targetRank) 7900f241638SMatthias Springer return failure(); 7910f241638SMatthias Springer 7920f241638SMatthias Springer ScopedContext scope(rewriter, xferOp.getLoc()); 7930f241638SMatthias Springer auto insertOp = getInsertOp(xferOp); 7940f241638SMatthias Springer auto vec = getResultVector(xferOp, rewriter); 7950f241638SMatthias Springer auto vecType = vec.getType().dyn_cast<VectorType>(); 7960f241638SMatthias Springer auto xferVecType = xferOp.getVectorType(); 7970f241638SMatthias Springer auto newXferVecType = VectorType::get(xferVecType.getShape().drop_front(), 7980f241638SMatthias Springer xferVecType.getElementType()); 7990f241638SMatthias Springer int64_t dimSize = xferVecType.getShape()[0]; 8000f241638SMatthias Springer 8010f241638SMatthias Springer // Generate fully unrolled loop of transfer ops. 
8020f241638SMatthias Springer for (int64_t i = 0; i < dimSize; ++i) { 8030f241638SMatthias Springer Value iv = std_constant_index(i); 8040f241638SMatthias Springer 8050f241638SMatthias Springer vec = generateInBoundsCheck( 8060f241638SMatthias Springer xferOp, iv, rewriter, unpackedDim(xferOp), TypeRange(vecType), 8070f241638SMatthias Springer /*inBoundsCase=*/ 8080f241638SMatthias Springer [&](OpBuilder &b, Location loc) { 8090f241638SMatthias Springer ScopedContext scope(b, loc); 8100f241638SMatthias Springer 8110f241638SMatthias Springer // Indices for the new transfer op. 8120f241638SMatthias Springer SmallVector<Value, 8> xferIndices; 8130f241638SMatthias Springer getXferIndices(xferOp, iv, xferIndices); 8140f241638SMatthias Springer 8150f241638SMatthias Springer // Indices for the new vector.insert op. 8160f241638SMatthias Springer SmallVector<int64_t, 8> insertionIndices; 8170f241638SMatthias Springer getInsertionIndices(xferOp, insertionIndices); 8180f241638SMatthias Springer insertionIndices.push_back(i); 8190f241638SMatthias Springer 8200f241638SMatthias Springer auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr()); 8210f241638SMatthias Springer auto newXferOpVal = 8220f241638SMatthias Springer vector_transfer_read( 8230f241638SMatthias Springer newXferVecType, xferOp.source(), xferIndices, 8240f241638SMatthias Springer AffineMapAttr::get(unpackedPermutationMap(xferOp, b)), 8250f241638SMatthias Springer xferOp.padding(), Value(), inBoundsAttr) 8260f241638SMatthias Springer .value; 8270f241638SMatthias Springer auto newXferOp = 8280f241638SMatthias Springer dyn_cast<TransferReadOp>(newXferOpVal.getDefiningOp()); 8290f241638SMatthias Springer 8300f241638SMatthias Springer maybeAssignMask(b, xferOp, newXferOp, i); 8310f241638SMatthias Springer 8320f241638SMatthias Springer return vector_insert(newXferOp, vec, insertionIndices).value; 8330f241638SMatthias Springer }, 8340f241638SMatthias Springer /*outOfBoundsCase=*/ 8350f241638SMatthias Springer 
[&](OpBuilder &b, Location loc) { 8360f241638SMatthias Springer // Loop through original (unmodified) vector. 8370f241638SMatthias Springer return vec; 8380f241638SMatthias Springer }); 8390f241638SMatthias Springer } 8400f241638SMatthias Springer 8410f241638SMatthias Springer if (insertOp) { 8420f241638SMatthias Springer // Rewrite single user of the old TransferReadOp, which was an InsertOp. 8430f241638SMatthias Springer rewriter.replaceOp(insertOp, vec); 8440f241638SMatthias Springer rewriter.eraseOp(xferOp); 8450f241638SMatthias Springer } else { 8460f241638SMatthias Springer rewriter.replaceOp(xferOp, vec); 8470f241638SMatthias Springer } 8480f241638SMatthias Springer 8490f241638SMatthias Springer return success(); 8500f241638SMatthias Springer } 8510f241638SMatthias Springer }; 8520f241638SMatthias Springer 8530f241638SMatthias Springer /// Progressive lowering of vector TransferWriteOp with unrolling: Unpack one 8540f241638SMatthias Springer /// dimension. This is similar to TransferOpConversion<TransferWriteOp>, but no 8550f241638SMatthias Springer /// memref buffer is allocated and the SCF loop is fully unrolled. 
8560f241638SMatthias Springer /// 8570f241638SMatthias Springer /// ``` 8580f241638SMatthias Springer /// E.g.: 8590f241638SMatthias Springer /// ``` 8600f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b, %c] 8610f241638SMatthias Springer /// : vector<5x4xf32>, memref<?x?x?xf32> 8620f241638SMatthias Springer /// ``` 8630f241638SMatthias Springer /// is rewritten to IR such as (simplified): 8640f241638SMatthias Springer /// ``` 8650f241638SMatthias Springer /// %v0 = vector.extract %vec[0] : vector<5x4xf32> 8660f241638SMatthias Springer /// vector.transfer_write %v0, %A[%a, %b, %c] : vector<4xf32>, memref<...> 8670f241638SMatthias Springer /// %v1 = vector.extract %vec[1] : vector<5x4xf32> 8680f241638SMatthias Springer /// vector.transfer_write %v1, %A[%a, %b + 1, %c] : vector<4xf32>, memref<...> 8690f241638SMatthias Springer /// ... 8700f241638SMatthias Springer /// %v4 = vector.extract %vec[4] : vector<5x4xf32> 8710f241638SMatthias Springer /// vector.transfer_write %v4, %A[%a, %b + 4, %c] : vector<4xf32>, memref<...> 8720f241638SMatthias Springer /// ``` 8730f241638SMatthias Springer /// 8740f241638SMatthias Springer /// Note: As an optimization, if the vector of the original TransferWriteOp 8750f241638SMatthias Springer /// was directly extracted from another vector via an ExtractOp `a`, extract 8760f241638SMatthias Springer /// the vectors for the newly generated TransferWriteOps from `a`'s input. By 8770f241638SMatthias Springer /// doing so, `a` may become dead, and the number of ExtractOps generated during 8780f241638SMatthias Springer /// recursive application of this pattern will be minimal. 
8790f241638SMatthias Springer struct UnrollTransferWriteConversion 880*2ca887deSMatthias Springer : public VectorToSCFPattern<TransferWriteOp> { 881*2ca887deSMatthias Springer using VectorToSCFPattern<TransferWriteOp>::VectorToSCFPattern; 8820f241638SMatthias Springer 8830f241638SMatthias Springer /// Return the vector from which newly generated ExtracOps will extract. 8840f241638SMatthias Springer Value getDataVector(TransferWriteOp xferOp) const { 8850f241638SMatthias Springer if (auto extractOp = getExtractOp(xferOp)) 8860f241638SMatthias Springer return extractOp.vector(); 8870f241638SMatthias Springer return xferOp.vector(); 8880f241638SMatthias Springer } 8890f241638SMatthias Springer 8900f241638SMatthias Springer /// If the input of the given TransferWriteOp is an ExtractOp, return it. 8910f241638SMatthias Springer vector::ExtractOp getExtractOp(TransferWriteOp xferOp) const { 8920f241638SMatthias Springer if (auto *op = xferOp.vector().getDefiningOp()) 8930f241638SMatthias Springer return dyn_cast<vector::ExtractOp>(op); 8940f241638SMatthias Springer return vector::ExtractOp(); 8950f241638SMatthias Springer } 8960f241638SMatthias Springer 8970f241638SMatthias Springer /// If the input of the given TransferWriteOp is an ExtractOp, return its 8980f241638SMatthias Springer /// indices. 8990f241638SMatthias Springer void getExtractionIndices(TransferWriteOp xferOp, 9000f241638SMatthias Springer SmallVector<int64_t, 8> &indices) const { 9010f241638SMatthias Springer if (auto extractOp = getExtractOp(xferOp)) { 9020f241638SMatthias Springer llvm::for_each(extractOp.position(), [&](Attribute attr) { 9030f241638SMatthias Springer indices.push_back(attr.dyn_cast<IntegerAttr>().getInt()); 9040f241638SMatthias Springer }); 9050f241638SMatthias Springer } 9060f241638SMatthias Springer } 9070f241638SMatthias Springer 9080f241638SMatthias Springer /// Rewrite the op: Unpack one dimension. 
Can handle masks, out-of-bounds 9090f241638SMatthias Springer /// accesses, and broadcasts and transposes in permutation maps. 9100f241638SMatthias Springer LogicalResult matchAndRewrite(TransferWriteOp xferOp, 9110f241638SMatthias Springer PatternRewriter &rewriter) const override { 912*2ca887deSMatthias Springer if (xferOp.getVectorType().getRank() <= options.targetRank) 9130f241638SMatthias Springer return failure(); 9140f241638SMatthias Springer 9150f241638SMatthias Springer ScopedContext scope(rewriter, xferOp.getLoc()); 9160f241638SMatthias Springer auto vec = getDataVector(xferOp); 9170f241638SMatthias Springer auto xferVecType = xferOp.getVectorType(); 9180f241638SMatthias Springer int64_t dimSize = xferVecType.getShape()[0]; 9190f241638SMatthias Springer 9200f241638SMatthias Springer // Generate fully unrolled loop of transfer ops. 9210f241638SMatthias Springer for (int64_t i = 0; i < dimSize; ++i) { 9220f241638SMatthias Springer Value iv = std_constant_index(i); 9230f241638SMatthias Springer 9240f241638SMatthias Springer generateInBoundsCheck( 9250f241638SMatthias Springer xferOp, iv, rewriter, unpackedDim(xferOp), 9260f241638SMatthias Springer /*inBoundsCase=*/[&](OpBuilder &b, Location loc) { 9270f241638SMatthias Springer ScopedContext scope(b, loc); 9280f241638SMatthias Springer 9290f241638SMatthias Springer // Indices for the new transfer op. 9300f241638SMatthias Springer SmallVector<Value, 8> xferIndices; 9310f241638SMatthias Springer getXferIndices(xferOp, iv, xferIndices); 9320f241638SMatthias Springer 9330f241638SMatthias Springer // Indices for the new vector.extract op. 
9340f241638SMatthias Springer SmallVector<int64_t, 8> extractionIndices; 9350f241638SMatthias Springer getExtractionIndices(xferOp, extractionIndices); 9360f241638SMatthias Springer extractionIndices.push_back(i); 9370f241638SMatthias Springer 9380f241638SMatthias Springer auto extracted = vector_extract(vec, extractionIndices).value; 9390f241638SMatthias Springer auto inBoundsAttr = dropFirstElem(b, xferOp.in_boundsAttr()); 9400f241638SMatthias Springer 9410f241638SMatthias Springer auto newXferOp = 9420f241638SMatthias Springer vector_transfer_write( 9430f241638SMatthias Springer Type(), extracted, xferOp.source(), xferIndices, 9440f241638SMatthias Springer AffineMapAttr::get(unpackedPermutationMap(xferOp, b)), 9450f241638SMatthias Springer Value(), inBoundsAttr) 9460f241638SMatthias Springer .op; 9470f241638SMatthias Springer 9480f241638SMatthias Springer maybeAssignMask(b, xferOp, newXferOp, i); 9490f241638SMatthias Springer }); 9500f241638SMatthias Springer } 9510f241638SMatthias Springer 9520f241638SMatthias Springer rewriter.eraseOp(xferOp); 9530f241638SMatthias Springer return success(); 9540f241638SMatthias Springer } 9550f241638SMatthias Springer }; 9560f241638SMatthias Springer 9570f241638SMatthias Springer /// Compute the indices into the memref for the LoadOp/StoreOp generated as 9580f241638SMatthias Springer /// part of TransferOp1dConversion. Return the memref dimension on which 9590f241638SMatthias Springer /// the transfer is operating. A return value of None indicates a broadcast. 
9600f241638SMatthias Springer template <typename OpTy> 9610f241638SMatthias Springer static Optional<int64_t> 9620f241638SMatthias Springer get1dMemrefIndices(OpTy xferOp, Value iv, 9630f241638SMatthias Springer SmallVector<Value, 8> &memrefIndices) { 9640f241638SMatthias Springer auto indices = xferOp.indices(); 9650f241638SMatthias Springer auto map = xferOp.permutation_map(); 9660f241638SMatthias Springer 9670f241638SMatthias Springer memrefIndices.append(indices.begin(), indices.end()); 9680f241638SMatthias Springer assert(map.getNumResults() == 1 && 9690f241638SMatthias Springer "Expected 1 permutation map result for 1D transfer"); 9700f241638SMatthias Springer if (auto expr = map.getResult(0).template dyn_cast<AffineDimExpr>()) { 9710f241638SMatthias Springer auto dim = expr.getPosition(); 9720f241638SMatthias Springer using edsc::op::operator+; 9730f241638SMatthias Springer memrefIndices[dim] = memrefIndices[dim] + iv; 9740f241638SMatthias Springer return dim; 9750f241638SMatthias Springer } 9760f241638SMatthias Springer 9770f241638SMatthias Springer assert(xferOp.isBroadcastDim(0) && 9780f241638SMatthias Springer "Expected AffineDimExpr or AffineConstantExpr"); 9790f241638SMatthias Springer return None; 9800f241638SMatthias Springer } 9810f241638SMatthias Springer 9820f241638SMatthias Springer /// Codegen strategy for TransferOp1dConversion, depending on the 9830f241638SMatthias Springer /// operation. 9840f241638SMatthias Springer template <typename OpTy> 9850f241638SMatthias Springer struct Strategy1d; 9860f241638SMatthias Springer 9870f241638SMatthias Springer /// Codegen strategy for TransferReadOp. 
9880f241638SMatthias Springer template <> 9890f241638SMatthias Springer struct Strategy1d<TransferReadOp> { 9900f241638SMatthias Springer static void generateForLoopBody(OpBuilder &builder, Location loc, 9910f241638SMatthias Springer TransferReadOp xferOp, Value iv, 9920f241638SMatthias Springer ValueRange loopState) { 9930f241638SMatthias Springer SmallVector<Value, 8> indices; 9940f241638SMatthias Springer auto dim = get1dMemrefIndices(xferOp, iv, indices); 9950f241638SMatthias Springer auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv); 9960f241638SMatthias Springer auto vec = loopState[0]; 9970f241638SMatthias Springer 9980f241638SMatthias Springer // In case of out-of-bounds access, leave `vec` as is (was initialized with 9990f241638SMatthias Springer // padding value). 10000f241638SMatthias Springer auto nextVec = generateInBoundsCheck( 10010f241638SMatthias Springer xferOp, iv, builder, dim, TypeRange(xferOp.getVectorType()), 10020f241638SMatthias Springer /*inBoundsCase=*/ 10030f241638SMatthias Springer [&](OpBuilder & /*b*/, Location loc) { 10040f241638SMatthias Springer auto val = memref_load(xferOp.source(), indices); 10050f241638SMatthias Springer return vector_insert_element(val, vec, ivI32.value).value; 10060f241638SMatthias Springer }, 10070f241638SMatthias Springer /*outOfBoundsCase=*/ 10080f241638SMatthias Springer [&](OpBuilder & /*b*/, Location loc) { return vec; }); 10090f241638SMatthias Springer builder.create<scf::YieldOp>(loc, nextVec); 10100f241638SMatthias Springer } 10110f241638SMatthias Springer 10120f241638SMatthias Springer static Value initialLoopState(TransferReadOp xferOp) { 10130f241638SMatthias Springer // Inititalize vector with padding value. 10140f241638SMatthias Springer return std_splat(xferOp.getVectorType(), xferOp.padding()).value; 10150f241638SMatthias Springer } 10160f241638SMatthias Springer }; 10170f241638SMatthias Springer 10180f241638SMatthias Springer /// Codegen strategy for TransferWriteOp. 
10190f241638SMatthias Springer template <> 10200f241638SMatthias Springer struct Strategy1d<TransferWriteOp> { 10210f241638SMatthias Springer static void generateForLoopBody(OpBuilder &builder, Location loc, 10220f241638SMatthias Springer TransferWriteOp xferOp, Value iv, 10230f241638SMatthias Springer ValueRange /*loopState*/) { 10240f241638SMatthias Springer SmallVector<Value, 8> indices; 10250f241638SMatthias Springer auto dim = get1dMemrefIndices(xferOp, iv, indices); 10260f241638SMatthias Springer auto ivI32 = std_index_cast(IntegerType::get(builder.getContext(), 32), iv); 10270f241638SMatthias Springer 10280f241638SMatthias Springer // Nothing to do in case of out-of-bounds access. 10290f241638SMatthias Springer generateInBoundsCheck( 10300f241638SMatthias Springer xferOp, iv, builder, dim, 10310f241638SMatthias Springer /*inBoundsCase=*/[&](OpBuilder & /*b*/, Location loc) { 10320f241638SMatthias Springer auto val = vector_extract_element(xferOp.vector(), ivI32.value); 10330f241638SMatthias Springer memref_store(val, xferOp.source(), indices); 10340f241638SMatthias Springer }); 10350f241638SMatthias Springer builder.create<scf::YieldOp>(loc); 10360f241638SMatthias Springer } 10370f241638SMatthias Springer 10380f241638SMatthias Springer static Value initialLoopState(TransferWriteOp xferOp) { return Value(); } 10390f241638SMatthias Springer }; 10400f241638SMatthias Springer 10410f241638SMatthias Springer /// Return true if the last dimension of the MemRefType has unit stride. 
10420f241638SMatthias Springer static bool isLastMemrefDimUnitStride(MemRefType type) { 10430f241638SMatthias Springer int64_t offset; 10440f241638SMatthias Springer SmallVector<int64_t, 4> strides; 10450f241638SMatthias Springer auto successStrides = getStridesAndOffset(type, strides, offset); 10460f241638SMatthias Springer return succeeded(successStrides) && strides.back() == 1; 10470f241638SMatthias Springer } 10480f241638SMatthias Springer 10490f241638SMatthias Springer /// Lower a 1D vector transfer op to SCF using scalar loads/stores. This is 10500f241638SMatthias Springer /// necessary in cases where a 1D vector transfer op cannot be lowered into 10510f241638SMatthias Springer /// vector load/stores due to non-unit strides or broadcasts: 10520f241638SMatthias Springer /// 10530f241638SMatthias Springer /// * Transfer dimension is not the last memref dimension 10540f241638SMatthias Springer /// * Transfer dimension is a broadcast (i.e., scalar load + broadcast) 10550f241638SMatthias Springer /// * Memref has a layout map with non-unit stride on the last dimension 10560f241638SMatthias Springer /// 10570f241638SMatthias Springer /// This pattern generates IR as follows: 10580f241638SMatthias Springer /// 10590f241638SMatthias Springer /// 1. Generate a for loop iterating over each vector element. 10600f241638SMatthias Springer /// 2. Inside the loop, generate a InsertElementOp or ExtractElementOp, 10610f241638SMatthias Springer /// depending on OpTy. 10620f241638SMatthias Springer /// 10630f241638SMatthias Springer /// TODO: In some cases (no masking, etc.), LLVM::MatrixColumnMajorLoadOp 10640f241638SMatthias Springer /// can be generated instead of TransferOp1dConversion. Add such a pattern 10650f241638SMatthias Springer /// to ConvertVectorToLLVM. 
10660f241638SMatthias Springer /// 10670f241638SMatthias Springer /// E.g.: 10680f241638SMatthias Springer /// ``` 10690f241638SMatthias Springer /// vector.transfer_write %vec, %A[%a, %b] 10700f241638SMatthias Springer /// {permutation_map = affine_map<(d0, d1) -> (d0)>, in_bounds = [true]} 10710f241638SMatthias Springer /// : vector<9xf32>, memref<?x?xf32> 10720f241638SMatthias Springer /// ``` 10730f241638SMatthias Springer /// Is rewritten to approximately the following pseudo-IR: 10740f241638SMatthias Springer /// ``` 10750f241638SMatthias Springer /// for i = 0 to 9 { 10760f241638SMatthias Springer /// %t = vector.extractelement %vec[i] : vector<9xf32> 10770f241638SMatthias Springer /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> 10780f241638SMatthias Springer /// } 10790f241638SMatthias Springer /// ``` 10800f241638SMatthias Springer template <typename OpTy> 1081*2ca887deSMatthias Springer struct TransferOp1dConversion : public VectorToSCFPattern<OpTy> { 1082*2ca887deSMatthias Springer using VectorToSCFPattern<OpTy>::VectorToSCFPattern; 10830f241638SMatthias Springer 10840f241638SMatthias Springer LogicalResult matchAndRewrite(OpTy xferOp, 10850f241638SMatthias Springer PatternRewriter &rewriter) const override { 10860f241638SMatthias Springer ScopedContext scope(rewriter, xferOp.getLoc()); 10870f241638SMatthias Springer auto map = xferOp.permutation_map(); 10880f241638SMatthias Springer auto memRefType = xferOp.getShapedType().template dyn_cast<MemRefType>(); 10890f241638SMatthias Springer 10900f241638SMatthias Springer if (!memRefType) 10910f241638SMatthias Springer return failure(); 10920f241638SMatthias Springer if (xferOp.getVectorType().getRank() != 1) 10930f241638SMatthias Springer return failure(); 10940f241638SMatthias Springer if (map.isMinorIdentity() && isLastMemrefDimUnitStride(memRefType)) 10950f241638SMatthias Springer return failure(); // Handled by ConvertVectorToLLVM 10960f241638SMatthias Springer 10970f241638SMatthias Springer // 
Loop bounds, step, state... 10980f241638SMatthias Springer auto vecType = xferOp.getVectorType(); 10990f241638SMatthias Springer auto lb = std_constant_index(0); 11000f241638SMatthias Springer auto ub = std_constant_index(vecType.getDimSize(0)); 11010f241638SMatthias Springer auto step = std_constant_index(1); 11020f241638SMatthias Springer auto loopState = Strategy1d<OpTy>::initialLoopState(xferOp); 11030f241638SMatthias Springer 11040f241638SMatthias Springer // Generate for loop. 11050f241638SMatthias Springer rewriter.replaceOpWithNewOp<scf::ForOp>( 11060f241638SMatthias Springer xferOp, lb, ub, step, loopState ? ValueRange(loopState) : ValueRange(), 11070f241638SMatthias Springer [&](OpBuilder &builder, Location loc, Value iv, ValueRange loopState) { 11080f241638SMatthias Springer ScopedContext nestedScope(builder, loc); 11090f241638SMatthias Springer Strategy1d<OpTy>::generateForLoopBody(builder, loc, xferOp, iv, 11100f241638SMatthias Springer loopState); 11110f241638SMatthias Springer }); 11120f241638SMatthias Springer 11130f241638SMatthias Springer return success(); 11140f241638SMatthias Springer } 11150f241638SMatthias Springer }; 11164ead2cf7SAlex Zinenko 1117df63eedeSBenjamin Kramer } // namespace 1118df63eedeSBenjamin Kramer 111951d30c34SBenjamin Kramer namespace mlir { 112051d30c34SBenjamin Kramer 11213393cc4cSNicolas Vasilache void populateVectorToSCFConversionPatterns( 1122dc4e913bSChris Lattner RewritePatternSet &patterns, const VectorTransferToSCFOptions &options) { 11230f241638SMatthias Springer if (options.unroll) { 11240f241638SMatthias Springer patterns.add<UnrollTransferReadConversion, UnrollTransferWriteConversion>( 1125*2ca887deSMatthias Springer patterns.getContext(), options); 11260f241638SMatthias Springer } else { 11270f241638SMatthias Springer patterns.add<PrepareTransferReadConversion, PrepareTransferWriteConversion, 11280f241638SMatthias Springer TransferOpConversion<TransferReadOp>, 1129*2ca887deSMatthias Springer 
TransferOpConversion<TransferWriteOp>>(patterns.getContext(), 1130*2ca887deSMatthias Springer options); 11310f241638SMatthias Springer } 11320f241638SMatthias Springer 1133*2ca887deSMatthias Springer if (options.targetRank == 1) { 11340f241638SMatthias Springer patterns.add<TransferOp1dConversion<TransferReadOp>, 1135*2ca887deSMatthias Springer TransferOp1dConversion<TransferWriteOp>>(patterns.getContext(), 1136*2ca887deSMatthias Springer options); 11370f241638SMatthias Springer } 11384ead2cf7SAlex Zinenko } 11393393cc4cSNicolas Vasilache 11403393cc4cSNicolas Vasilache } // namespace mlir 11413393cc4cSNicolas Vasilache 11425f9e0466SNicolas Vasilache namespace { 11435f9e0466SNicolas Vasilache 11445f9e0466SNicolas Vasilache struct ConvertVectorToSCFPass 11455f9e0466SNicolas Vasilache : public ConvertVectorToSCFBase<ConvertVectorToSCFPass> { 11465f9e0466SNicolas Vasilache ConvertVectorToSCFPass() = default; 11475f9e0466SNicolas Vasilache ConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { 11485f9e0466SNicolas Vasilache this->fullUnroll = options.unroll; 1149*2ca887deSMatthias Springer this->targetRank = options.targetRank; 11505f9e0466SNicolas Vasilache } 11515f9e0466SNicolas Vasilache 11525f9e0466SNicolas Vasilache void runOnFunction() override { 1153*2ca887deSMatthias Springer VectorTransferToSCFOptions options; 1154*2ca887deSMatthias Springer options.setUnroll(fullUnroll); 1155*2ca887deSMatthias Springer options.setTargetRank(targetRank); 1156*2ca887deSMatthias Springer 1157dc4e913bSChris Lattner RewritePatternSet patterns(getFunction().getContext()); 1158*2ca887deSMatthias Springer populateVectorToSCFConversionPatterns(patterns, options); 1159e21adfa3SRiver Riddle (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); 11605f9e0466SNicolas Vasilache } 11615f9e0466SNicolas Vasilache }; 11625f9e0466SNicolas Vasilache 11635f9e0466SNicolas Vasilache } // namespace 11645f9e0466SNicolas Vasilache 11655f9e0466SNicolas Vasilache 
std::unique_ptr<Pass> 11665f9e0466SNicolas Vasilache mlir::createConvertVectorToSCFPass(const VectorTransferToSCFOptions &options) { 11675f9e0466SNicolas Vasilache return std::make_unique<ConvertVectorToSCFPass>(options); 11685f9e0466SNicolas Vasilache } 1169