10deeaacaSLei Zhang //===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===// 20deeaacaSLei Zhang // 30deeaacaSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 40deeaacaSLei Zhang // See https://llvm.org/LICENSE.txt for license information. 50deeaacaSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 60deeaacaSLei Zhang // 70deeaacaSLei Zhang //===----------------------------------------------------------------------===// 80deeaacaSLei Zhang // 90deeaacaSLei Zhang // This transformation pass folds loading/storing from/to subview ops into 100deeaacaSLei Zhang // loading/storing from/to the original memref. 110deeaacaSLei Zhang // 120deeaacaSLei Zhang //===----------------------------------------------------------------------===// 130deeaacaSLei Zhang 14fd15e2b8SMaheshRavishankar #include "mlir/Dialect/Affine/IR/AffineOps.h" 15*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 160deeaacaSLei Zhang #include "mlir/Dialect/MemRef/IR/MemRef.h" 170deeaacaSLei Zhang #include "mlir/Dialect/MemRef/Transforms/Passes.h" 180deeaacaSLei Zhang #include "mlir/Dialect/StandardOps/IR/Ops.h" 190deeaacaSLei Zhang #include "mlir/Dialect/Vector/VectorOps.h" 200deeaacaSLei Zhang #include "mlir/IR/BuiltinTypes.h" 210deeaacaSLei Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 220deeaacaSLei Zhang 230deeaacaSLei Zhang using namespace mlir; 240deeaacaSLei Zhang 250deeaacaSLei Zhang //===----------------------------------------------------------------------===// 260deeaacaSLei Zhang // Utility functions 270deeaacaSLei Zhang //===----------------------------------------------------------------------===// 280deeaacaSLei Zhang 290deeaacaSLei Zhang /// Given the 'indices' of an load/store operation where the memref is a result 300deeaacaSLei Zhang /// of a subview op, returns the indices w.r.t to the source memref of the 310deeaacaSLei Zhang /// subview op. For example 320deeaacaSLei Zhang /// 330deeaacaSLei Zhang /// %0 = ... : memref<12x42xf32> 340deeaacaSLei Zhang /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to 350deeaacaSLei Zhang /// memref<4x4xf32, offset=?, strides=[?, ?]> 360deeaacaSLei Zhang /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> 370deeaacaSLei Zhang /// 380deeaacaSLei Zhang /// could be folded into 390deeaacaSLei Zhang /// 400deeaacaSLei Zhang /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : 410deeaacaSLei Zhang /// memref<12x42xf32> 420deeaacaSLei Zhang static LogicalResult 430deeaacaSLei Zhang resolveSourceIndices(Location loc, PatternRewriter &rewriter, 440deeaacaSLei Zhang memref::SubViewOp subViewOp, ValueRange indices, 450deeaacaSLei Zhang SmallVectorImpl<Value> &sourceIndices) { 46fd15e2b8SMaheshRavishankar SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets(); 47fd15e2b8SMaheshRavishankar SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes(); 48fd15e2b8SMaheshRavishankar SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides(); 490deeaacaSLei Zhang 50fd15e2b8SMaheshRavishankar SmallVector<Value> useIndices; 51fd15e2b8SMaheshRavishankar // Check if this is rank-reducing case. Then for every unit-dim size add a 52fd15e2b8SMaheshRavishankar // zero to the indices. 53fd15e2b8SMaheshRavishankar unsigned resultDim = 0; 544cf9bf6cSMaheshRavishankar llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims(); 554cf9bf6cSMaheshRavishankar for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) { 564cf9bf6cSMaheshRavishankar if (unusedDims.count(dim)) 57*a54f4eaeSMogball useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0)); 584cf9bf6cSMaheshRavishankar else 59fd15e2b8SMaheshRavishankar useIndices.push_back(indices[resultDim++]); 60fd15e2b8SMaheshRavishankar } 61fd15e2b8SMaheshRavishankar if (useIndices.size() != mixedOffsets.size()) 62fd15e2b8SMaheshRavishankar return failure(); 63fd15e2b8SMaheshRavishankar sourceIndices.resize(useIndices.size()); 64fd15e2b8SMaheshRavishankar for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) { 65fd15e2b8SMaheshRavishankar SmallVector<Value> dynamicOperands; 66fd15e2b8SMaheshRavishankar AffineExpr expr = rewriter.getAffineDimExpr(0); 67fd15e2b8SMaheshRavishankar unsigned numSymbols = 0; 68fd15e2b8SMaheshRavishankar dynamicOperands.push_back(useIndices[index]); 69fd15e2b8SMaheshRavishankar 70fd15e2b8SMaheshRavishankar // Multiply the stride; 71fd15e2b8SMaheshRavishankar if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) { 72fd15e2b8SMaheshRavishankar expr = expr * attr.cast<IntegerAttr>().getInt(); 73fd15e2b8SMaheshRavishankar } else { 74fd15e2b8SMaheshRavishankar dynamicOperands.push_back(mixedStrides[index].get<Value>()); 75fd15e2b8SMaheshRavishankar expr = expr * rewriter.getAffineSymbolExpr(numSymbols++); 76fd15e2b8SMaheshRavishankar } 77fd15e2b8SMaheshRavishankar 78fd15e2b8SMaheshRavishankar // Add the offset. 79fd15e2b8SMaheshRavishankar if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) { 80fd15e2b8SMaheshRavishankar expr = expr + attr.cast<IntegerAttr>().getInt(); 81fd15e2b8SMaheshRavishankar } else { 82fd15e2b8SMaheshRavishankar dynamicOperands.push_back(mixedOffsets[index].get<Value>()); 83fd15e2b8SMaheshRavishankar expr = expr + rewriter.getAffineSymbolExpr(numSymbols++); 84fd15e2b8SMaheshRavishankar } 85fd15e2b8SMaheshRavishankar Location loc = subViewOp.getLoc(); 86fd15e2b8SMaheshRavishankar sourceIndices[index] = rewriter.create<AffineApplyOp>( 87fd15e2b8SMaheshRavishankar loc, AffineMap::get(1, numSymbols, expr), dynamicOperands); 880deeaacaSLei Zhang } 890deeaacaSLei Zhang return success(); 900deeaacaSLei Zhang } 910deeaacaSLei Zhang 920deeaacaSLei Zhang /// Helpers to access the memref operand for each op. 930deeaacaSLei Zhang static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); } 940deeaacaSLei Zhang 950deeaacaSLei Zhang static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); } 960deeaacaSLei Zhang 970deeaacaSLei Zhang static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); } 980deeaacaSLei Zhang 990deeaacaSLei Zhang static Value getMemRefOperand(vector::TransferWriteOp op) { 1000deeaacaSLei Zhang return op.source(); 1010deeaacaSLei Zhang } 1020deeaacaSLei Zhang 1034cf9bf6cSMaheshRavishankar /// Given the permutation map of the original 1044cf9bf6cSMaheshRavishankar /// `vector.transfer_read`/`vector.transfer_write` operations compute the 1054cf9bf6cSMaheshRavishankar /// permutation map to use after the subview is folded with it. 1064cf9bf6cSMaheshRavishankar static AffineMap getPermutationMap(MLIRContext *context, 1074cf9bf6cSMaheshRavishankar memref::SubViewOp subViewOp, 1084cf9bf6cSMaheshRavishankar AffineMap currPermutationMap) { 1094cf9bf6cSMaheshRavishankar llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims(); 1104cf9bf6cSMaheshRavishankar SmallVector<AffineExpr> exprs; 1114cf9bf6cSMaheshRavishankar int64_t sourceRank = subViewOp.getSourceType().getRank(); 1124cf9bf6cSMaheshRavishankar for (auto dim : llvm::seq<int64_t>(0, sourceRank)) { 1134cf9bf6cSMaheshRavishankar if (unusedDims.count(dim)) 1144cf9bf6cSMaheshRavishankar continue; 115b12e4c17Sthomasraoux exprs.push_back(getAffineDimExpr(dim, context)); 1164cf9bf6cSMaheshRavishankar } 1174cf9bf6cSMaheshRavishankar auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context); 1184cf9bf6cSMaheshRavishankar return currPermutationMap.compose(resultDimToSourceDimMap); 1194cf9bf6cSMaheshRavishankar } 1204cf9bf6cSMaheshRavishankar 1210deeaacaSLei Zhang //===----------------------------------------------------------------------===// 1220deeaacaSLei Zhang // Patterns 1230deeaacaSLei Zhang //===----------------------------------------------------------------------===// 1240deeaacaSLei Zhang 1250deeaacaSLei Zhang namespace { 1260deeaacaSLei Zhang /// Merges subview operation with load/transferRead operation. 1270deeaacaSLei Zhang template <typename OpTy> 1280deeaacaSLei Zhang class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> { 1290deeaacaSLei Zhang public: 1300deeaacaSLei Zhang using OpRewritePattern<OpTy>::OpRewritePattern; 1310deeaacaSLei Zhang 1320deeaacaSLei Zhang LogicalResult matchAndRewrite(OpTy loadOp, 1330deeaacaSLei Zhang PatternRewriter &rewriter) const override; 1340deeaacaSLei Zhang 1350deeaacaSLei Zhang private: 1360deeaacaSLei Zhang void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp, 1370deeaacaSLei Zhang ArrayRef<Value> sourceIndices, 1380deeaacaSLei Zhang PatternRewriter &rewriter) const; 1390deeaacaSLei Zhang }; 1400deeaacaSLei Zhang 1410deeaacaSLei Zhang /// Merges subview operation with store/transferWriteOp operation. 1420deeaacaSLei Zhang template <typename OpTy> 1430deeaacaSLei Zhang class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> { 1440deeaacaSLei Zhang public: 1450deeaacaSLei Zhang using OpRewritePattern<OpTy>::OpRewritePattern; 1460deeaacaSLei Zhang 1470deeaacaSLei Zhang LogicalResult matchAndRewrite(OpTy storeOp, 1480deeaacaSLei Zhang PatternRewriter &rewriter) const override; 1490deeaacaSLei Zhang 1500deeaacaSLei Zhang private: 1510deeaacaSLei Zhang void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp, 1520deeaacaSLei Zhang ArrayRef<Value> sourceIndices, 1530deeaacaSLei Zhang PatternRewriter &rewriter) const; 1540deeaacaSLei Zhang }; 1550deeaacaSLei Zhang 1560deeaacaSLei Zhang template <> 1570deeaacaSLei Zhang void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp( 1580deeaacaSLei Zhang memref::LoadOp loadOp, memref::SubViewOp subViewOp, 1590deeaacaSLei Zhang ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 1600deeaacaSLei Zhang rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(), 1610deeaacaSLei Zhang sourceIndices); 1620deeaacaSLei Zhang } 1630deeaacaSLei Zhang 1640deeaacaSLei Zhang template <> 1650deeaacaSLei Zhang void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp( 1660deeaacaSLei Zhang vector::TransferReadOp loadOp, memref::SubViewOp subViewOp, 1670deeaacaSLei Zhang ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 1680deeaacaSLei Zhang rewriter.replaceOpWithNewOp<vector::TransferReadOp>( 1690deeaacaSLei Zhang loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices, 1704cf9bf6cSMaheshRavishankar getPermutationMap(rewriter.getContext(), subViewOp, 1714cf9bf6cSMaheshRavishankar loadOp.permutation_map()), 1724cf9bf6cSMaheshRavishankar loadOp.padding(), loadOp.in_boundsAttr()); 1730deeaacaSLei Zhang } 1740deeaacaSLei Zhang 1750deeaacaSLei Zhang template <> 1760deeaacaSLei Zhang void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp( 1770deeaacaSLei Zhang memref::StoreOp storeOp, memref::SubViewOp subViewOp, 1780deeaacaSLei Zhang ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 1790deeaacaSLei Zhang rewriter.replaceOpWithNewOp<memref::StoreOp>( 1800deeaacaSLei Zhang storeOp, storeOp.value(), subViewOp.source(), sourceIndices); 1810deeaacaSLei Zhang } 1820deeaacaSLei Zhang 1830deeaacaSLei Zhang template <> 1840deeaacaSLei Zhang void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp( 1850deeaacaSLei Zhang vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp, 1860deeaacaSLei Zhang ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 1870deeaacaSLei Zhang rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 1880deeaacaSLei Zhang transferWriteOp, transferWriteOp.vector(), subViewOp.source(), 1894cf9bf6cSMaheshRavishankar sourceIndices, 1904cf9bf6cSMaheshRavishankar getPermutationMap(rewriter.getContext(), subViewOp, 1914cf9bf6cSMaheshRavishankar transferWriteOp.permutation_map()), 1920deeaacaSLei Zhang transferWriteOp.in_boundsAttr()); 1930deeaacaSLei Zhang } 1940deeaacaSLei Zhang } // namespace 1950deeaacaSLei Zhang 1960deeaacaSLei Zhang template <typename OpTy> 1970deeaacaSLei Zhang LogicalResult 1980deeaacaSLei Zhang LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp, 1990deeaacaSLei Zhang PatternRewriter &rewriter) const { 2000deeaacaSLei Zhang auto subViewOp = 2010deeaacaSLei Zhang getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>(); 2020deeaacaSLei Zhang if (!subViewOp) 2030deeaacaSLei Zhang return failure(); 2040deeaacaSLei Zhang 2050deeaacaSLei Zhang SmallVector<Value, 4> sourceIndices; 2060deeaacaSLei Zhang if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, 2070deeaacaSLei Zhang loadOp.indices(), sourceIndices))) 2080deeaacaSLei Zhang return failure(); 2090deeaacaSLei Zhang 2100deeaacaSLei Zhang replaceOp(loadOp, subViewOp, sourceIndices, rewriter); 2110deeaacaSLei Zhang return success(); 2120deeaacaSLei Zhang } 2130deeaacaSLei Zhang 2140deeaacaSLei Zhang template <typename OpTy> 2150deeaacaSLei Zhang LogicalResult 2160deeaacaSLei Zhang StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp, 2170deeaacaSLei Zhang PatternRewriter &rewriter) const { 2180deeaacaSLei Zhang auto subViewOp = 2190deeaacaSLei Zhang getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>(); 2200deeaacaSLei Zhang if (!subViewOp) 2210deeaacaSLei Zhang return failure(); 2220deeaacaSLei Zhang 2230deeaacaSLei Zhang SmallVector<Value, 4> sourceIndices; 2240deeaacaSLei Zhang if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, 2250deeaacaSLei Zhang storeOp.indices(), sourceIndices))) 2260deeaacaSLei Zhang return failure(); 2270deeaacaSLei Zhang 2280deeaacaSLei Zhang replaceOp(storeOp, subViewOp, sourceIndices, rewriter); 2290deeaacaSLei Zhang return success(); 2300deeaacaSLei Zhang } 2310deeaacaSLei Zhang 2320deeaacaSLei Zhang void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) { 2330deeaacaSLei Zhang patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>, 2340deeaacaSLei Zhang LoadOpOfSubViewFolder<vector::TransferReadOp>, 2350deeaacaSLei Zhang StoreOpOfSubViewFolder<memref::StoreOp>, 2360deeaacaSLei Zhang StoreOpOfSubViewFolder<vector::TransferWriteOp>>( 2370deeaacaSLei Zhang patterns.getContext()); 2380deeaacaSLei Zhang } 2390deeaacaSLei Zhang 2400deeaacaSLei Zhang //===----------------------------------------------------------------------===// 2410deeaacaSLei Zhang // Pass registration 2420deeaacaSLei Zhang //===----------------------------------------------------------------------===// 2430deeaacaSLei Zhang 2440deeaacaSLei Zhang namespace { 2450deeaacaSLei Zhang 2460deeaacaSLei Zhang #define GEN_PASS_CLASSES 2470deeaacaSLei Zhang #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" 2480deeaacaSLei Zhang 2490deeaacaSLei Zhang struct FoldSubViewOpsPass final 2500deeaacaSLei Zhang : public FoldSubViewOpsBase<FoldSubViewOpsPass> { 2510deeaacaSLei Zhang void runOnOperation() override; 2520deeaacaSLei Zhang }; 2530deeaacaSLei Zhang 2540deeaacaSLei Zhang } // namespace 2550deeaacaSLei Zhang 2560deeaacaSLei Zhang void FoldSubViewOpsPass::runOnOperation() { 2570deeaacaSLei Zhang RewritePatternSet patterns(&getContext()); 2580deeaacaSLei Zhang memref::populateFoldSubViewOpPatterns(patterns); 2590deeaacaSLei Zhang (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 2600deeaacaSLei Zhang std::move(patterns)); 2610deeaacaSLei Zhang } 2620deeaacaSLei Zhang 2630deeaacaSLei Zhang std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() { 2640deeaacaSLei Zhang return std::make_unique<FoldSubViewOpsPass>(); 2650deeaacaSLei Zhang } 266