//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
130deeaacaSLei Zhang
1411a7635bSRiver Riddle #include "PassDetail.h"
15fd15e2b8SMaheshRavishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
16a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
170deeaacaSLei Zhang #include "mlir/Dialect/MemRef/IR/MemRef.h"
180deeaacaSLei Zhang #include "mlir/Dialect/MemRef/Transforms/Passes.h"
1999ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
2034ff99a0STres Popp #include "mlir/IR/BuiltinTypes.h"
210deeaacaSLei Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
226635c12aSBenjamin Kramer #include "llvm/ADT/SmallBitVector.h"
230deeaacaSLei Zhang
240deeaacaSLei Zhang using namespace mlir;
250deeaacaSLei Zhang
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
290deeaacaSLei Zhang
300deeaacaSLei Zhang /// Given the 'indices' of an load/store operation where the memref is a result
310deeaacaSLei Zhang /// of a subview op, returns the indices w.r.t to the source memref of the
320deeaacaSLei Zhang /// subview op. For example
330deeaacaSLei Zhang ///
340deeaacaSLei Zhang /// %0 = ... : memref<12x42xf32>
350deeaacaSLei Zhang /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
360deeaacaSLei Zhang /// memref<4x4xf32, offset=?, strides=[?, ?]>
370deeaacaSLei Zhang /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
380deeaacaSLei Zhang ///
390deeaacaSLei Zhang /// could be folded into
400deeaacaSLei Zhang ///
410deeaacaSLei Zhang /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
420deeaacaSLei Zhang /// memref<12x42xf32>
430deeaacaSLei Zhang static LogicalResult
resolveSourceIndices(Location loc,PatternRewriter & rewriter,memref::SubViewOp subViewOp,ValueRange indices,SmallVectorImpl<Value> & sourceIndices)440deeaacaSLei Zhang resolveSourceIndices(Location loc, PatternRewriter &rewriter,
450deeaacaSLei Zhang memref::SubViewOp subViewOp, ValueRange indices,
460deeaacaSLei Zhang SmallVectorImpl<Value> &sourceIndices) {
47fd15e2b8SMaheshRavishankar SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
48fd15e2b8SMaheshRavishankar SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
49fd15e2b8SMaheshRavishankar SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
500deeaacaSLei Zhang
51fd15e2b8SMaheshRavishankar SmallVector<Value> useIndices;
52fd15e2b8SMaheshRavishankar // Check if this is rank-reducing case. Then for every unit-dim size add a
53fd15e2b8SMaheshRavishankar // zero to the indices.
54fd15e2b8SMaheshRavishankar unsigned resultDim = 0;
556635c12aSBenjamin Kramer llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
564cf9bf6cSMaheshRavishankar for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
576635c12aSBenjamin Kramer if (unusedDims.test(dim))
58a54f4eaeSMogball useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
594cf9bf6cSMaheshRavishankar else
60fd15e2b8SMaheshRavishankar useIndices.push_back(indices[resultDim++]);
61fd15e2b8SMaheshRavishankar }
62fd15e2b8SMaheshRavishankar if (useIndices.size() != mixedOffsets.size())
63fd15e2b8SMaheshRavishankar return failure();
64fd15e2b8SMaheshRavishankar sourceIndices.resize(useIndices.size());
65fd15e2b8SMaheshRavishankar for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
66fd15e2b8SMaheshRavishankar SmallVector<Value> dynamicOperands;
67fd15e2b8SMaheshRavishankar AffineExpr expr = rewriter.getAffineDimExpr(0);
68fd15e2b8SMaheshRavishankar unsigned numSymbols = 0;
69fd15e2b8SMaheshRavishankar dynamicOperands.push_back(useIndices[index]);
70fd15e2b8SMaheshRavishankar
71fd15e2b8SMaheshRavishankar // Multiply the stride;
72fd15e2b8SMaheshRavishankar if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
73fd15e2b8SMaheshRavishankar expr = expr * attr.cast<IntegerAttr>().getInt();
74fd15e2b8SMaheshRavishankar } else {
75fd15e2b8SMaheshRavishankar dynamicOperands.push_back(mixedStrides[index].get<Value>());
76fd15e2b8SMaheshRavishankar expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
77fd15e2b8SMaheshRavishankar }
78fd15e2b8SMaheshRavishankar
79fd15e2b8SMaheshRavishankar // Add the offset.
80fd15e2b8SMaheshRavishankar if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
81fd15e2b8SMaheshRavishankar expr = expr + attr.cast<IntegerAttr>().getInt();
82fd15e2b8SMaheshRavishankar } else {
83fd15e2b8SMaheshRavishankar dynamicOperands.push_back(mixedOffsets[index].get<Value>());
84fd15e2b8SMaheshRavishankar expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
85fd15e2b8SMaheshRavishankar }
86fd15e2b8SMaheshRavishankar Location loc = subViewOp.getLoc();
87fd15e2b8SMaheshRavishankar sourceIndices[index] = rewriter.create<AffineApplyOp>(
88fd15e2b8SMaheshRavishankar loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
890deeaacaSLei Zhang }
900deeaacaSLei Zhang return success();
910deeaacaSLei Zhang }
920deeaacaSLei Zhang
930deeaacaSLei Zhang /// Helpers to access the memref operand for each op.
94f8a2cd67SUday Bondhugula template <typename LoadOrStoreOpTy>
getMemRefOperand(LoadOrStoreOpTy op)95f8a2cd67SUday Bondhugula static Value getMemRefOperand(LoadOrStoreOpTy op) {
9604235d07SJacques Pienaar return op.getMemref();
97f8a2cd67SUday Bondhugula }
980deeaacaSLei Zhang
getMemRefOperand(vector::TransferReadOp op)997c38fd60SJacques Pienaar static Value getMemRefOperand(vector::TransferReadOp op) {
1007c38fd60SJacques Pienaar return op.getSource();
1017c38fd60SJacques Pienaar }
1020deeaacaSLei Zhang
getMemRefOperand(vector::TransferWriteOp op)1030deeaacaSLei Zhang static Value getMemRefOperand(vector::TransferWriteOp op) {
1047c38fd60SJacques Pienaar return op.getSource();
1050deeaacaSLei Zhang }
1060deeaacaSLei Zhang
1074cf9bf6cSMaheshRavishankar /// Given the permutation map of the original
1084cf9bf6cSMaheshRavishankar /// `vector.transfer_read`/`vector.transfer_write` operations compute the
1094cf9bf6cSMaheshRavishankar /// permutation map to use after the subview is folded with it.
getPermutationMapAttr(MLIRContext * context,memref::SubViewOp subViewOp,AffineMap currPermutationMap)110c537a943SNicolas Vasilache static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
1114cf9bf6cSMaheshRavishankar memref::SubViewOp subViewOp,
1124cf9bf6cSMaheshRavishankar AffineMap currPermutationMap) {
1136635c12aSBenjamin Kramer llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
1144cf9bf6cSMaheshRavishankar SmallVector<AffineExpr> exprs;
1154cf9bf6cSMaheshRavishankar int64_t sourceRank = subViewOp.getSourceType().getRank();
1164cf9bf6cSMaheshRavishankar for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
1176635c12aSBenjamin Kramer if (unusedDims.test(dim))
1184cf9bf6cSMaheshRavishankar continue;
119b12e4c17Sthomasraoux exprs.push_back(getAffineDimExpr(dim, context));
1204cf9bf6cSMaheshRavishankar }
1214cf9bf6cSMaheshRavishankar auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
122c537a943SNicolas Vasilache return AffineMapAttr::get(
123c537a943SNicolas Vasilache currPermutationMap.compose(resultDimToSourceDimMap));
1244cf9bf6cSMaheshRavishankar }
1254cf9bf6cSMaheshRavishankar
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
1290deeaacaSLei Zhang
1300deeaacaSLei Zhang namespace {
1310deeaacaSLei Zhang /// Merges subview operation with load/transferRead operation.
1320deeaacaSLei Zhang template <typename OpTy>
1330deeaacaSLei Zhang class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
1340deeaacaSLei Zhang public:
1350deeaacaSLei Zhang using OpRewritePattern<OpTy>::OpRewritePattern;
1360deeaacaSLei Zhang
1370deeaacaSLei Zhang LogicalResult matchAndRewrite(OpTy loadOp,
1380deeaacaSLei Zhang PatternRewriter &rewriter) const override;
1390deeaacaSLei Zhang
1400deeaacaSLei Zhang private:
1410deeaacaSLei Zhang void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
1420deeaacaSLei Zhang ArrayRef<Value> sourceIndices,
1430deeaacaSLei Zhang PatternRewriter &rewriter) const;
1440deeaacaSLei Zhang };
1450deeaacaSLei Zhang
1460deeaacaSLei Zhang /// Merges subview operation with store/transferWriteOp operation.
1470deeaacaSLei Zhang template <typename OpTy>
1480deeaacaSLei Zhang class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
1490deeaacaSLei Zhang public:
1500deeaacaSLei Zhang using OpRewritePattern<OpTy>::OpRewritePattern;
1510deeaacaSLei Zhang
1520deeaacaSLei Zhang LogicalResult matchAndRewrite(OpTy storeOp,
1530deeaacaSLei Zhang PatternRewriter &rewriter) const override;
1540deeaacaSLei Zhang
1550deeaacaSLei Zhang private:
1560deeaacaSLei Zhang void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
1570deeaacaSLei Zhang ArrayRef<Value> sourceIndices,
1580deeaacaSLei Zhang PatternRewriter &rewriter) const;
1590deeaacaSLei Zhang };
1600deeaacaSLei Zhang
161f8a2cd67SUday Bondhugula template <typename LoadOpTy>
replaceOp(LoadOpTy loadOp,memref::SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const162f8a2cd67SUday Bondhugula void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
163f8a2cd67SUday Bondhugula LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
164f8a2cd67SUday Bondhugula PatternRewriter &rewriter) const {
165*136d746eSJacques Pienaar rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.getSource(),
1660deeaacaSLei Zhang sourceIndices);
1670deeaacaSLei Zhang }
1680deeaacaSLei Zhang
1690deeaacaSLei Zhang template <>
replaceOp(vector::TransferReadOp transferReadOp,memref::SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const1700deeaacaSLei Zhang void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
171c537a943SNicolas Vasilache vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
1720deeaacaSLei Zhang ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
173c537a943SNicolas Vasilache // TODO: support 0-d corner case.
174c537a943SNicolas Vasilache if (transferReadOp.getTransferRank() == 0)
175c537a943SNicolas Vasilache return;
1760deeaacaSLei Zhang rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
177*136d746eSJacques Pienaar transferReadOp, transferReadOp.getVectorType(), subViewOp.getSource(),
178c537a943SNicolas Vasilache sourceIndices,
179c537a943SNicolas Vasilache getPermutationMapAttr(rewriter.getContext(), subViewOp,
1807c38fd60SJacques Pienaar transferReadOp.getPermutationMap()),
1817c38fd60SJacques Pienaar transferReadOp.getPadding(),
1827c38fd60SJacques Pienaar /*mask=*/Value(), transferReadOp.getInBoundsAttr());
1830deeaacaSLei Zhang }
1840deeaacaSLei Zhang
185f8a2cd67SUday Bondhugula template <typename StoreOpTy>
replaceOp(StoreOpTy storeOp,memref::SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const186f8a2cd67SUday Bondhugula void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
187f8a2cd67SUday Bondhugula StoreOpTy storeOp, memref::SubViewOp subViewOp,
1880deeaacaSLei Zhang ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
18904235d07SJacques Pienaar rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.getValue(),
190*136d746eSJacques Pienaar subViewOp.getSource(), sourceIndices);
1910deeaacaSLei Zhang }
1920deeaacaSLei Zhang
1930deeaacaSLei Zhang template <>
replaceOp(vector::TransferWriteOp transferWriteOp,memref::SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const1940deeaacaSLei Zhang void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
1950deeaacaSLei Zhang vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
1960deeaacaSLei Zhang ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
197c537a943SNicolas Vasilache // TODO: support 0-d corner case.
198c537a943SNicolas Vasilache if (transferWriteOp.getTransferRank() == 0)
199c537a943SNicolas Vasilache return;
2000deeaacaSLei Zhang rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
201*136d746eSJacques Pienaar transferWriteOp, transferWriteOp.getVector(), subViewOp.getSource(),
2024cf9bf6cSMaheshRavishankar sourceIndices,
203c537a943SNicolas Vasilache getPermutationMapAttr(rewriter.getContext(), subViewOp,
2047c38fd60SJacques Pienaar transferWriteOp.getPermutationMap()),
2057c38fd60SJacques Pienaar transferWriteOp.getInBoundsAttr());
2060deeaacaSLei Zhang }
2070deeaacaSLei Zhang } // namespace
2080deeaacaSLei Zhang
2090deeaacaSLei Zhang template <typename OpTy>
2100deeaacaSLei Zhang LogicalResult
matchAndRewrite(OpTy loadOp,PatternRewriter & rewriter) const2110deeaacaSLei Zhang LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
2120deeaacaSLei Zhang PatternRewriter &rewriter) const {
2130deeaacaSLei Zhang auto subViewOp =
2140deeaacaSLei Zhang getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
2150deeaacaSLei Zhang if (!subViewOp)
2160deeaacaSLei Zhang return failure();
2170deeaacaSLei Zhang
2180deeaacaSLei Zhang SmallVector<Value, 4> sourceIndices;
21934ff99a0STres Popp if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
2207c38fd60SJacques Pienaar loadOp.getIndices(), sourceIndices)))
2210deeaacaSLei Zhang return failure();
2220deeaacaSLei Zhang
2230deeaacaSLei Zhang replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
2240deeaacaSLei Zhang return success();
2250deeaacaSLei Zhang }
2260deeaacaSLei Zhang
2270deeaacaSLei Zhang template <typename OpTy>
2280deeaacaSLei Zhang LogicalResult
matchAndRewrite(OpTy storeOp,PatternRewriter & rewriter) const2290deeaacaSLei Zhang StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
2300deeaacaSLei Zhang PatternRewriter &rewriter) const {
2310deeaacaSLei Zhang auto subViewOp =
2320deeaacaSLei Zhang getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
2330deeaacaSLei Zhang if (!subViewOp)
2340deeaacaSLei Zhang return failure();
2350deeaacaSLei Zhang
2360deeaacaSLei Zhang SmallVector<Value, 4> sourceIndices;
2370deeaacaSLei Zhang if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
2387c38fd60SJacques Pienaar storeOp.getIndices(), sourceIndices)))
2390deeaacaSLei Zhang return failure();
2400deeaacaSLei Zhang
2410deeaacaSLei Zhang replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
2420deeaacaSLei Zhang return success();
2430deeaacaSLei Zhang }
2440deeaacaSLei Zhang
populateFoldSubViewOpPatterns(RewritePatternSet & patterns)2450deeaacaSLei Zhang void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
246f8a2cd67SUday Bondhugula patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
247f8a2cd67SUday Bondhugula LoadOpOfSubViewFolder<memref::LoadOp>,
2480deeaacaSLei Zhang LoadOpOfSubViewFolder<vector::TransferReadOp>,
249f8a2cd67SUday Bondhugula StoreOpOfSubViewFolder<AffineStoreOp>,
2500deeaacaSLei Zhang StoreOpOfSubViewFolder<memref::StoreOp>,
2510deeaacaSLei Zhang StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
2520deeaacaSLei Zhang patterns.getContext());
2530deeaacaSLei Zhang }
2540deeaacaSLei Zhang
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
2580deeaacaSLei Zhang
2590deeaacaSLei Zhang namespace {
2600deeaacaSLei Zhang
2610deeaacaSLei Zhang struct FoldSubViewOpsPass final
2620deeaacaSLei Zhang : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
2630deeaacaSLei Zhang void runOnOperation() override;
2640deeaacaSLei Zhang };
2650deeaacaSLei Zhang
2660deeaacaSLei Zhang } // namespace
2670deeaacaSLei Zhang
runOnOperation()2680deeaacaSLei Zhang void FoldSubViewOpsPass::runOnOperation() {
2690deeaacaSLei Zhang RewritePatternSet patterns(&getContext());
2700deeaacaSLei Zhang memref::populateFoldSubViewOpPatterns(patterns);
27111a7635bSRiver Riddle (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
2720deeaacaSLei Zhang }
2730deeaacaSLei Zhang
createFoldSubViewOpsPass()2740deeaacaSLei Zhang std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
2750deeaacaSLei Zhang return std::make_unique<FoldSubViewOpsPass>();
2760deeaacaSLei Zhang }
277