10deeaacaSLei Zhang //===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===// 20deeaacaSLei Zhang // 30deeaacaSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 40deeaacaSLei Zhang // See https://llvm.org/LICENSE.txt for license information. 50deeaacaSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 60deeaacaSLei Zhang // 70deeaacaSLei Zhang //===----------------------------------------------------------------------===// 80deeaacaSLei Zhang // 90deeaacaSLei Zhang // This transformation pass folds loading/storing from/to subview ops into 100deeaacaSLei Zhang // loading/storing from/to the original memref. 110deeaacaSLei Zhang // 120deeaacaSLei Zhang //===----------------------------------------------------------------------===// 130deeaacaSLei Zhang 14fd15e2b8SMaheshRavishankar #include "mlir/Dialect/Affine/IR/AffineOps.h" 15a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 160deeaacaSLei Zhang #include "mlir/Dialect/MemRef/IR/MemRef.h" 170deeaacaSLei Zhang #include "mlir/Dialect/MemRef/Transforms/Passes.h" 180deeaacaSLei Zhang #include "mlir/Dialect/StandardOps/IR/Ops.h" 190deeaacaSLei Zhang #include "mlir/Dialect/Vector/VectorOps.h" 200deeaacaSLei Zhang #include "mlir/IR/BuiltinTypes.h" 210deeaacaSLei Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 220deeaacaSLei Zhang 230deeaacaSLei Zhang using namespace mlir; 240deeaacaSLei Zhang 250deeaacaSLei Zhang //===----------------------------------------------------------------------===// 260deeaacaSLei Zhang // Utility functions 270deeaacaSLei Zhang //===----------------------------------------------------------------------===// 280deeaacaSLei Zhang 290deeaacaSLei Zhang /// Given the 'indices' of an load/store operation where the memref is a result 300deeaacaSLei Zhang /// of a subview op, returns the indices w.r.t to the source memref of the 310deeaacaSLei Zhang /// subview op. For example 320deeaacaSLei Zhang /// 330deeaacaSLei Zhang /// %0 = ... : memref<12x42xf32> 340deeaacaSLei Zhang /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to 350deeaacaSLei Zhang /// memref<4x4xf32, offset=?, strides=[?, ?]> 360deeaacaSLei Zhang /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> 370deeaacaSLei Zhang /// 380deeaacaSLei Zhang /// could be folded into 390deeaacaSLei Zhang /// 400deeaacaSLei Zhang /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : 410deeaacaSLei Zhang /// memref<12x42xf32> 420deeaacaSLei Zhang static LogicalResult 430deeaacaSLei Zhang resolveSourceIndices(Location loc, PatternRewriter &rewriter, 440deeaacaSLei Zhang memref::SubViewOp subViewOp, ValueRange indices, 450deeaacaSLei Zhang SmallVectorImpl<Value> &sourceIndices) { 46fd15e2b8SMaheshRavishankar SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets(); 47fd15e2b8SMaheshRavishankar SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes(); 48fd15e2b8SMaheshRavishankar SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides(); 490deeaacaSLei Zhang 50fd15e2b8SMaheshRavishankar SmallVector<Value> useIndices; 51fd15e2b8SMaheshRavishankar // Check if this is rank-reducing case. Then for every unit-dim size add a 52fd15e2b8SMaheshRavishankar // zero to the indices. 53fd15e2b8SMaheshRavishankar unsigned resultDim = 0; 544cf9bf6cSMaheshRavishankar llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims(); 554cf9bf6cSMaheshRavishankar for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) { 564cf9bf6cSMaheshRavishankar if (unusedDims.count(dim)) 57a54f4eaeSMogball useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0)); 584cf9bf6cSMaheshRavishankar else 59fd15e2b8SMaheshRavishankar useIndices.push_back(indices[resultDim++]); 60fd15e2b8SMaheshRavishankar } 61fd15e2b8SMaheshRavishankar if (useIndices.size() != mixedOffsets.size()) 62fd15e2b8SMaheshRavishankar return failure(); 63fd15e2b8SMaheshRavishankar sourceIndices.resize(useIndices.size()); 64fd15e2b8SMaheshRavishankar for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) { 65fd15e2b8SMaheshRavishankar SmallVector<Value> dynamicOperands; 66fd15e2b8SMaheshRavishankar AffineExpr expr = rewriter.getAffineDimExpr(0); 67fd15e2b8SMaheshRavishankar unsigned numSymbols = 0; 68fd15e2b8SMaheshRavishankar dynamicOperands.push_back(useIndices[index]); 69fd15e2b8SMaheshRavishankar 70fd15e2b8SMaheshRavishankar // Multiply the stride; 71fd15e2b8SMaheshRavishankar if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) { 72fd15e2b8SMaheshRavishankar expr = expr * attr.cast<IntegerAttr>().getInt(); 73fd15e2b8SMaheshRavishankar } else { 74fd15e2b8SMaheshRavishankar dynamicOperands.push_back(mixedStrides[index].get<Value>()); 75fd15e2b8SMaheshRavishankar expr = expr * rewriter.getAffineSymbolExpr(numSymbols++); 76fd15e2b8SMaheshRavishankar } 77fd15e2b8SMaheshRavishankar 78fd15e2b8SMaheshRavishankar // Add the offset. 79fd15e2b8SMaheshRavishankar if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) { 80fd15e2b8SMaheshRavishankar expr = expr + attr.cast<IntegerAttr>().getInt(); 81fd15e2b8SMaheshRavishankar } else { 82fd15e2b8SMaheshRavishankar dynamicOperands.push_back(mixedOffsets[index].get<Value>()); 83fd15e2b8SMaheshRavishankar expr = expr + rewriter.getAffineSymbolExpr(numSymbols++); 84fd15e2b8SMaheshRavishankar } 85fd15e2b8SMaheshRavishankar Location loc = subViewOp.getLoc(); 86fd15e2b8SMaheshRavishankar sourceIndices[index] = rewriter.create<AffineApplyOp>( 87fd15e2b8SMaheshRavishankar loc, AffineMap::get(1, numSymbols, expr), dynamicOperands); 880deeaacaSLei Zhang } 890deeaacaSLei Zhang return success(); 900deeaacaSLei Zhang } 910deeaacaSLei Zhang 920deeaacaSLei Zhang /// Helpers to access the memref operand for each op. 93*f8a2cd67SUday Bondhugula template <typename LoadOrStoreOpTy> 94*f8a2cd67SUday Bondhugula static Value getMemRefOperand(LoadOrStoreOpTy op) { 95*f8a2cd67SUday Bondhugula return op.memref(); 96*f8a2cd67SUday Bondhugula } 970deeaacaSLei Zhang 980deeaacaSLei Zhang static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); } 990deeaacaSLei Zhang 1000deeaacaSLei Zhang static Value getMemRefOperand(vector::TransferWriteOp op) { 1010deeaacaSLei Zhang return op.source(); 1020deeaacaSLei Zhang } 1030deeaacaSLei Zhang 1044cf9bf6cSMaheshRavishankar /// Given the permutation map of the original 1054cf9bf6cSMaheshRavishankar /// `vector.transfer_read`/`vector.transfer_write` operations compute the 1064cf9bf6cSMaheshRavishankar /// permutation map to use after the subview is folded with it. 107c537a943SNicolas Vasilache static AffineMapAttr getPermutationMapAttr(MLIRContext *context, 1084cf9bf6cSMaheshRavishankar memref::SubViewOp subViewOp, 1094cf9bf6cSMaheshRavishankar AffineMap currPermutationMap) { 1104cf9bf6cSMaheshRavishankar llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims(); 1114cf9bf6cSMaheshRavishankar SmallVector<AffineExpr> exprs; 1124cf9bf6cSMaheshRavishankar int64_t sourceRank = subViewOp.getSourceType().getRank(); 1134cf9bf6cSMaheshRavishankar for (auto dim : llvm::seq<int64_t>(0, sourceRank)) { 1144cf9bf6cSMaheshRavishankar if (unusedDims.count(dim)) 1154cf9bf6cSMaheshRavishankar continue; 116b12e4c17Sthomasraoux exprs.push_back(getAffineDimExpr(dim, context)); 1174cf9bf6cSMaheshRavishankar } 1184cf9bf6cSMaheshRavishankar auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context); 119c537a943SNicolas Vasilache return AffineMapAttr::get( 120c537a943SNicolas Vasilache currPermutationMap.compose(resultDimToSourceDimMap)); 1214cf9bf6cSMaheshRavishankar } 1224cf9bf6cSMaheshRavishankar 1230deeaacaSLei Zhang //===----------------------------------------------------------------------===// 1240deeaacaSLei Zhang // Patterns 1250deeaacaSLei Zhang //===----------------------------------------------------------------------===// 1260deeaacaSLei Zhang 1270deeaacaSLei Zhang namespace { 1280deeaacaSLei Zhang /// Merges subview operation with load/transferRead operation. 1290deeaacaSLei Zhang template <typename OpTy> 1300deeaacaSLei Zhang class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> { 1310deeaacaSLei Zhang public: 1320deeaacaSLei Zhang using OpRewritePattern<OpTy>::OpRewritePattern; 1330deeaacaSLei Zhang 1340deeaacaSLei Zhang LogicalResult matchAndRewrite(OpTy loadOp, 1350deeaacaSLei Zhang PatternRewriter &rewriter) const override; 1360deeaacaSLei Zhang 1370deeaacaSLei Zhang private: 1380deeaacaSLei Zhang void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp, 1390deeaacaSLei Zhang ArrayRef<Value> sourceIndices, 1400deeaacaSLei Zhang PatternRewriter &rewriter) const; 1410deeaacaSLei Zhang }; 1420deeaacaSLei Zhang 1430deeaacaSLei Zhang /// Merges subview operation with store/transferWriteOp operation. 1440deeaacaSLei Zhang template <typename OpTy> 1450deeaacaSLei Zhang class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> { 1460deeaacaSLei Zhang public: 1470deeaacaSLei Zhang using OpRewritePattern<OpTy>::OpRewritePattern; 1480deeaacaSLei Zhang 1490deeaacaSLei Zhang LogicalResult matchAndRewrite(OpTy storeOp, 1500deeaacaSLei Zhang PatternRewriter &rewriter) const override; 1510deeaacaSLei Zhang 1520deeaacaSLei Zhang private: 1530deeaacaSLei Zhang void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp, 1540deeaacaSLei Zhang ArrayRef<Value> sourceIndices, 1550deeaacaSLei Zhang PatternRewriter &rewriter) const; 1560deeaacaSLei Zhang }; 1570deeaacaSLei Zhang 158*f8a2cd67SUday Bondhugula template <typename LoadOpTy> 159*f8a2cd67SUday Bondhugula void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp( 160*f8a2cd67SUday Bondhugula LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices, 161*f8a2cd67SUday Bondhugula PatternRewriter &rewriter) const { 162*f8a2cd67SUday Bondhugula rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.source(), 1630deeaacaSLei Zhang sourceIndices); 1640deeaacaSLei Zhang } 1650deeaacaSLei Zhang 1660deeaacaSLei Zhang template <> 1670deeaacaSLei Zhang void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp( 168c537a943SNicolas Vasilache vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp, 1690deeaacaSLei Zhang ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 170c537a943SNicolas Vasilache // TODO: support 0-d corner case. 171c537a943SNicolas Vasilache if (transferReadOp.getTransferRank() == 0) 172c537a943SNicolas Vasilache return; 1730deeaacaSLei Zhang rewriter.replaceOpWithNewOp<vector::TransferReadOp>( 174c537a943SNicolas Vasilache transferReadOp, transferReadOp.getVectorType(), subViewOp.source(), 175c537a943SNicolas Vasilache sourceIndices, 176c537a943SNicolas Vasilache getPermutationMapAttr(rewriter.getContext(), subViewOp, 177c537a943SNicolas Vasilache transferReadOp.permutation_map()), 178c537a943SNicolas Vasilache transferReadOp.padding(), 179c537a943SNicolas Vasilache /*mask=*/Value(), transferReadOp.in_boundsAttr()); 1800deeaacaSLei Zhang } 1810deeaacaSLei Zhang 182*f8a2cd67SUday Bondhugula template <typename StoreOpTy> 183*f8a2cd67SUday Bondhugula void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp( 184*f8a2cd67SUday Bondhugula StoreOpTy storeOp, memref::SubViewOp subViewOp, 1850deeaacaSLei Zhang ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 186*f8a2cd67SUday Bondhugula rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.value(), 187*f8a2cd67SUday Bondhugula subViewOp.source(), sourceIndices); 1880deeaacaSLei Zhang } 1890deeaacaSLei Zhang 1900deeaacaSLei Zhang template <> 1910deeaacaSLei Zhang void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp( 1920deeaacaSLei Zhang vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp, 1930deeaacaSLei Zhang ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 194c537a943SNicolas Vasilache // TODO: support 0-d corner case. 195c537a943SNicolas Vasilache if (transferWriteOp.getTransferRank() == 0) 196c537a943SNicolas Vasilache return; 1970deeaacaSLei Zhang rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 1980deeaacaSLei Zhang transferWriteOp, transferWriteOp.vector(), subViewOp.source(), 1994cf9bf6cSMaheshRavishankar sourceIndices, 200c537a943SNicolas Vasilache getPermutationMapAttr(rewriter.getContext(), subViewOp, 2014cf9bf6cSMaheshRavishankar transferWriteOp.permutation_map()), 2020deeaacaSLei Zhang transferWriteOp.in_boundsAttr()); 2030deeaacaSLei Zhang } 2040deeaacaSLei Zhang } // namespace 2050deeaacaSLei Zhang 2060deeaacaSLei Zhang template <typename OpTy> 2070deeaacaSLei Zhang LogicalResult 2080deeaacaSLei Zhang LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp, 2090deeaacaSLei Zhang PatternRewriter &rewriter) const { 2100deeaacaSLei Zhang auto subViewOp = 2110deeaacaSLei Zhang getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>(); 2120deeaacaSLei Zhang if (!subViewOp) 2130deeaacaSLei Zhang return failure(); 2140deeaacaSLei Zhang 2150deeaacaSLei Zhang SmallVector<Value, 4> sourceIndices; 2160deeaacaSLei Zhang if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, 2170deeaacaSLei Zhang loadOp.indices(), sourceIndices))) 2180deeaacaSLei Zhang return failure(); 2190deeaacaSLei Zhang 2200deeaacaSLei Zhang replaceOp(loadOp, subViewOp, sourceIndices, rewriter); 2210deeaacaSLei Zhang return success(); 2220deeaacaSLei Zhang } 2230deeaacaSLei Zhang 2240deeaacaSLei Zhang template <typename OpTy> 2250deeaacaSLei Zhang LogicalResult 2260deeaacaSLei Zhang StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp, 2270deeaacaSLei Zhang PatternRewriter &rewriter) const { 2280deeaacaSLei Zhang auto subViewOp = 2290deeaacaSLei Zhang getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>(); 2300deeaacaSLei Zhang if (!subViewOp) 2310deeaacaSLei Zhang return failure(); 2320deeaacaSLei Zhang 2330deeaacaSLei Zhang SmallVector<Value, 4> sourceIndices; 2340deeaacaSLei Zhang if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, 2350deeaacaSLei Zhang storeOp.indices(), sourceIndices))) 2360deeaacaSLei Zhang return failure(); 2370deeaacaSLei Zhang 2380deeaacaSLei Zhang replaceOp(storeOp, subViewOp, sourceIndices, rewriter); 2390deeaacaSLei Zhang return success(); 2400deeaacaSLei Zhang } 2410deeaacaSLei Zhang 2420deeaacaSLei Zhang void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) { 243*f8a2cd67SUday Bondhugula patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>, 244*f8a2cd67SUday Bondhugula LoadOpOfSubViewFolder<memref::LoadOp>, 2450deeaacaSLei Zhang LoadOpOfSubViewFolder<vector::TransferReadOp>, 246*f8a2cd67SUday Bondhugula StoreOpOfSubViewFolder<AffineStoreOp>, 2470deeaacaSLei Zhang StoreOpOfSubViewFolder<memref::StoreOp>, 2480deeaacaSLei Zhang StoreOpOfSubViewFolder<vector::TransferWriteOp>>( 2490deeaacaSLei Zhang patterns.getContext()); 2500deeaacaSLei Zhang } 2510deeaacaSLei Zhang 2520deeaacaSLei Zhang //===----------------------------------------------------------------------===// 2530deeaacaSLei Zhang // Pass registration 2540deeaacaSLei Zhang //===----------------------------------------------------------------------===// 2550deeaacaSLei Zhang 2560deeaacaSLei Zhang namespace { 2570deeaacaSLei Zhang 2580deeaacaSLei Zhang #define GEN_PASS_CLASSES 2590deeaacaSLei Zhang #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" 2600deeaacaSLei Zhang 2610deeaacaSLei Zhang struct FoldSubViewOpsPass final 2620deeaacaSLei Zhang : public FoldSubViewOpsBase<FoldSubViewOpsPass> { 2630deeaacaSLei Zhang void runOnOperation() override; 2640deeaacaSLei Zhang }; 2650deeaacaSLei Zhang 2660deeaacaSLei Zhang } // namespace 2670deeaacaSLei Zhang 2680deeaacaSLei Zhang void FoldSubViewOpsPass::runOnOperation() { 2690deeaacaSLei Zhang RewritePatternSet patterns(&getContext()); 2700deeaacaSLei Zhang memref::populateFoldSubViewOpPatterns(patterns); 2710deeaacaSLei Zhang (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 2720deeaacaSLei Zhang std::move(patterns)); 2730deeaacaSLei Zhang } 2740deeaacaSLei Zhang 2750deeaacaSLei Zhang std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() { 2760deeaacaSLei Zhang return std::make_unique<FoldSubViewOpsPass>(); 2770deeaacaSLei Zhang } 278