//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
130deeaacaSLei Zhang 
1411a7635bSRiver Riddle #include "PassDetail.h"
15fd15e2b8SMaheshRavishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
16a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
170deeaacaSLei Zhang #include "mlir/Dialect/MemRef/IR/MemRef.h"
180deeaacaSLei Zhang #include "mlir/Dialect/MemRef/Transforms/Passes.h"
1999ef9eebSMatthias Springer #include "mlir/Dialect/Vector/IR/VectorOps.h"
2034ff99a0STres Popp #include "mlir/IR/BuiltinTypes.h"
210deeaacaSLei Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
226635c12aSBenjamin Kramer #include "llvm/ADT/SmallBitVector.h"
230deeaacaSLei Zhang 
240deeaacaSLei Zhang using namespace mlir;
250deeaacaSLei Zhang 
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
290deeaacaSLei Zhang 
300deeaacaSLei Zhang /// Given the 'indices' of an load/store operation where the memref is a result
310deeaacaSLei Zhang /// of a subview op, returns the indices w.r.t to the source memref of the
320deeaacaSLei Zhang /// subview op. For example
330deeaacaSLei Zhang ///
340deeaacaSLei Zhang /// %0 = ... : memref<12x42xf32>
350deeaacaSLei Zhang /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
360deeaacaSLei Zhang ///          memref<4x4xf32, offset=?, strides=[?, ?]>
370deeaacaSLei Zhang /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
380deeaacaSLei Zhang ///
390deeaacaSLei Zhang /// could be folded into
400deeaacaSLei Zhang ///
410deeaacaSLei Zhang /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
420deeaacaSLei Zhang ///          memref<12x42xf32>
430deeaacaSLei Zhang static LogicalResult
resolveSourceIndices(Location loc,PatternRewriter & rewriter,memref::SubViewOp subViewOp,ValueRange indices,SmallVectorImpl<Value> & sourceIndices)440deeaacaSLei Zhang resolveSourceIndices(Location loc, PatternRewriter &rewriter,
450deeaacaSLei Zhang                      memref::SubViewOp subViewOp, ValueRange indices,
460deeaacaSLei Zhang                      SmallVectorImpl<Value> &sourceIndices) {
47fd15e2b8SMaheshRavishankar   SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
48fd15e2b8SMaheshRavishankar   SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
49fd15e2b8SMaheshRavishankar   SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
500deeaacaSLei Zhang 
51fd15e2b8SMaheshRavishankar   SmallVector<Value> useIndices;
52fd15e2b8SMaheshRavishankar   // Check if this is rank-reducing case. Then for every unit-dim size add a
53fd15e2b8SMaheshRavishankar   // zero to the indices.
54fd15e2b8SMaheshRavishankar   unsigned resultDim = 0;
556635c12aSBenjamin Kramer   llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
564cf9bf6cSMaheshRavishankar   for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
576635c12aSBenjamin Kramer     if (unusedDims.test(dim))
58a54f4eaeSMogball       useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
594cf9bf6cSMaheshRavishankar     else
60fd15e2b8SMaheshRavishankar       useIndices.push_back(indices[resultDim++]);
61fd15e2b8SMaheshRavishankar   }
62fd15e2b8SMaheshRavishankar   if (useIndices.size() != mixedOffsets.size())
63fd15e2b8SMaheshRavishankar     return failure();
64fd15e2b8SMaheshRavishankar   sourceIndices.resize(useIndices.size());
65fd15e2b8SMaheshRavishankar   for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
66fd15e2b8SMaheshRavishankar     SmallVector<Value> dynamicOperands;
67fd15e2b8SMaheshRavishankar     AffineExpr expr = rewriter.getAffineDimExpr(0);
68fd15e2b8SMaheshRavishankar     unsigned numSymbols = 0;
69fd15e2b8SMaheshRavishankar     dynamicOperands.push_back(useIndices[index]);
70fd15e2b8SMaheshRavishankar 
71fd15e2b8SMaheshRavishankar     // Multiply the stride;
72fd15e2b8SMaheshRavishankar     if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
73fd15e2b8SMaheshRavishankar       expr = expr * attr.cast<IntegerAttr>().getInt();
74fd15e2b8SMaheshRavishankar     } else {
75fd15e2b8SMaheshRavishankar       dynamicOperands.push_back(mixedStrides[index].get<Value>());
76fd15e2b8SMaheshRavishankar       expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
77fd15e2b8SMaheshRavishankar     }
78fd15e2b8SMaheshRavishankar 
79fd15e2b8SMaheshRavishankar     // Add the offset.
80fd15e2b8SMaheshRavishankar     if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
81fd15e2b8SMaheshRavishankar       expr = expr + attr.cast<IntegerAttr>().getInt();
82fd15e2b8SMaheshRavishankar     } else {
83fd15e2b8SMaheshRavishankar       dynamicOperands.push_back(mixedOffsets[index].get<Value>());
84fd15e2b8SMaheshRavishankar       expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
85fd15e2b8SMaheshRavishankar     }
86fd15e2b8SMaheshRavishankar     Location loc = subViewOp.getLoc();
87fd15e2b8SMaheshRavishankar     sourceIndices[index] = rewriter.create<AffineApplyOp>(
88fd15e2b8SMaheshRavishankar         loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
890deeaacaSLei Zhang   }
900deeaacaSLei Zhang   return success();
910deeaacaSLei Zhang }
920deeaacaSLei Zhang 
930deeaacaSLei Zhang /// Helpers to access the memref operand for each op.
94f8a2cd67SUday Bondhugula template <typename LoadOrStoreOpTy>
getMemRefOperand(LoadOrStoreOpTy op)95f8a2cd67SUday Bondhugula static Value getMemRefOperand(LoadOrStoreOpTy op) {
9604235d07SJacques Pienaar   return op.getMemref();
97f8a2cd67SUday Bondhugula }
980deeaacaSLei Zhang 
getMemRefOperand(vector::TransferReadOp op)997c38fd60SJacques Pienaar static Value getMemRefOperand(vector::TransferReadOp op) {
1007c38fd60SJacques Pienaar   return op.getSource();
1017c38fd60SJacques Pienaar }
1020deeaacaSLei Zhang 
getMemRefOperand(vector::TransferWriteOp op)1030deeaacaSLei Zhang static Value getMemRefOperand(vector::TransferWriteOp op) {
1047c38fd60SJacques Pienaar   return op.getSource();
1050deeaacaSLei Zhang }
1060deeaacaSLei Zhang 
1074cf9bf6cSMaheshRavishankar /// Given the permutation map of the original
1084cf9bf6cSMaheshRavishankar /// `vector.transfer_read`/`vector.transfer_write` operations compute the
1094cf9bf6cSMaheshRavishankar /// permutation map to use after the subview is folded with it.
getPermutationMapAttr(MLIRContext * context,memref::SubViewOp subViewOp,AffineMap currPermutationMap)110c537a943SNicolas Vasilache static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
1114cf9bf6cSMaheshRavishankar                                            memref::SubViewOp subViewOp,
1124cf9bf6cSMaheshRavishankar                                            AffineMap currPermutationMap) {
1136635c12aSBenjamin Kramer   llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
1144cf9bf6cSMaheshRavishankar   SmallVector<AffineExpr> exprs;
1154cf9bf6cSMaheshRavishankar   int64_t sourceRank = subViewOp.getSourceType().getRank();
1164cf9bf6cSMaheshRavishankar   for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
1176635c12aSBenjamin Kramer     if (unusedDims.test(dim))
1184cf9bf6cSMaheshRavishankar       continue;
119b12e4c17Sthomasraoux     exprs.push_back(getAffineDimExpr(dim, context));
1204cf9bf6cSMaheshRavishankar   }
1214cf9bf6cSMaheshRavishankar   auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
122c537a943SNicolas Vasilache   return AffineMapAttr::get(
123c537a943SNicolas Vasilache       currPermutationMap.compose(resultDimToSourceDimMap));
1244cf9bf6cSMaheshRavishankar }
1254cf9bf6cSMaheshRavishankar 
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
1290deeaacaSLei Zhang 
1300deeaacaSLei Zhang namespace {
1310deeaacaSLei Zhang /// Merges subview operation with load/transferRead operation.
1320deeaacaSLei Zhang template <typename OpTy>
1330deeaacaSLei Zhang class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
1340deeaacaSLei Zhang public:
1350deeaacaSLei Zhang   using OpRewritePattern<OpTy>::OpRewritePattern;
1360deeaacaSLei Zhang 
1370deeaacaSLei Zhang   LogicalResult matchAndRewrite(OpTy loadOp,
1380deeaacaSLei Zhang                                 PatternRewriter &rewriter) const override;
1390deeaacaSLei Zhang 
1400deeaacaSLei Zhang private:
1410deeaacaSLei Zhang   void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
1420deeaacaSLei Zhang                  ArrayRef<Value> sourceIndices,
1430deeaacaSLei Zhang                  PatternRewriter &rewriter) const;
1440deeaacaSLei Zhang };
1450deeaacaSLei Zhang 
1460deeaacaSLei Zhang /// Merges subview operation with store/transferWriteOp operation.
1470deeaacaSLei Zhang template <typename OpTy>
1480deeaacaSLei Zhang class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
1490deeaacaSLei Zhang public:
1500deeaacaSLei Zhang   using OpRewritePattern<OpTy>::OpRewritePattern;
1510deeaacaSLei Zhang 
1520deeaacaSLei Zhang   LogicalResult matchAndRewrite(OpTy storeOp,
1530deeaacaSLei Zhang                                 PatternRewriter &rewriter) const override;
1540deeaacaSLei Zhang 
1550deeaacaSLei Zhang private:
1560deeaacaSLei Zhang   void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
1570deeaacaSLei Zhang                  ArrayRef<Value> sourceIndices,
1580deeaacaSLei Zhang                  PatternRewriter &rewriter) const;
1590deeaacaSLei Zhang };
1600deeaacaSLei Zhang 
161f8a2cd67SUday Bondhugula template <typename LoadOpTy>
replaceOp(LoadOpTy loadOp,memref::SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const162f8a2cd67SUday Bondhugula void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
163f8a2cd67SUday Bondhugula     LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
164f8a2cd67SUday Bondhugula     PatternRewriter &rewriter) const {
165*136d746eSJacques Pienaar   rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.getSource(),
1660deeaacaSLei Zhang                                         sourceIndices);
1670deeaacaSLei Zhang }
1680deeaacaSLei Zhang 
1690deeaacaSLei Zhang template <>
replaceOp(vector::TransferReadOp transferReadOp,memref::SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const1700deeaacaSLei Zhang void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
171c537a943SNicolas Vasilache     vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
1720deeaacaSLei Zhang     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
173c537a943SNicolas Vasilache   // TODO: support 0-d corner case.
174c537a943SNicolas Vasilache   if (transferReadOp.getTransferRank() == 0)
175c537a943SNicolas Vasilache     return;
1760deeaacaSLei Zhang   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
177*136d746eSJacques Pienaar       transferReadOp, transferReadOp.getVectorType(), subViewOp.getSource(),
178c537a943SNicolas Vasilache       sourceIndices,
179c537a943SNicolas Vasilache       getPermutationMapAttr(rewriter.getContext(), subViewOp,
1807c38fd60SJacques Pienaar                             transferReadOp.getPermutationMap()),
1817c38fd60SJacques Pienaar       transferReadOp.getPadding(),
1827c38fd60SJacques Pienaar       /*mask=*/Value(), transferReadOp.getInBoundsAttr());
1830deeaacaSLei Zhang }
1840deeaacaSLei Zhang 
185f8a2cd67SUday Bondhugula template <typename StoreOpTy>
replaceOp(StoreOpTy storeOp,memref::SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const186f8a2cd67SUday Bondhugula void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
187f8a2cd67SUday Bondhugula     StoreOpTy storeOp, memref::SubViewOp subViewOp,
1880deeaacaSLei Zhang     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
18904235d07SJacques Pienaar   rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.getValue(),
190*136d746eSJacques Pienaar                                          subViewOp.getSource(), sourceIndices);
1910deeaacaSLei Zhang }
1920deeaacaSLei Zhang 
1930deeaacaSLei Zhang template <>
replaceOp(vector::TransferWriteOp transferWriteOp,memref::SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const1940deeaacaSLei Zhang void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
1950deeaacaSLei Zhang     vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
1960deeaacaSLei Zhang     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
197c537a943SNicolas Vasilache   // TODO: support 0-d corner case.
198c537a943SNicolas Vasilache   if (transferWriteOp.getTransferRank() == 0)
199c537a943SNicolas Vasilache     return;
2000deeaacaSLei Zhang   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
201*136d746eSJacques Pienaar       transferWriteOp, transferWriteOp.getVector(), subViewOp.getSource(),
2024cf9bf6cSMaheshRavishankar       sourceIndices,
203c537a943SNicolas Vasilache       getPermutationMapAttr(rewriter.getContext(), subViewOp,
2047c38fd60SJacques Pienaar                             transferWriteOp.getPermutationMap()),
2057c38fd60SJacques Pienaar       transferWriteOp.getInBoundsAttr());
2060deeaacaSLei Zhang }
2070deeaacaSLei Zhang } // namespace
2080deeaacaSLei Zhang 
2090deeaacaSLei Zhang template <typename OpTy>
2100deeaacaSLei Zhang LogicalResult
matchAndRewrite(OpTy loadOp,PatternRewriter & rewriter) const2110deeaacaSLei Zhang LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
2120deeaacaSLei Zhang                                              PatternRewriter &rewriter) const {
2130deeaacaSLei Zhang   auto subViewOp =
2140deeaacaSLei Zhang       getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
2150deeaacaSLei Zhang   if (!subViewOp)
2160deeaacaSLei Zhang     return failure();
2170deeaacaSLei Zhang 
2180deeaacaSLei Zhang   SmallVector<Value, 4> sourceIndices;
21934ff99a0STres Popp   if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
2207c38fd60SJacques Pienaar                                   loadOp.getIndices(), sourceIndices)))
2210deeaacaSLei Zhang     return failure();
2220deeaacaSLei Zhang 
2230deeaacaSLei Zhang   replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
2240deeaacaSLei Zhang   return success();
2250deeaacaSLei Zhang }
2260deeaacaSLei Zhang 
2270deeaacaSLei Zhang template <typename OpTy>
2280deeaacaSLei Zhang LogicalResult
matchAndRewrite(OpTy storeOp,PatternRewriter & rewriter) const2290deeaacaSLei Zhang StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
2300deeaacaSLei Zhang                                               PatternRewriter &rewriter) const {
2310deeaacaSLei Zhang   auto subViewOp =
2320deeaacaSLei Zhang       getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
2330deeaacaSLei Zhang   if (!subViewOp)
2340deeaacaSLei Zhang     return failure();
2350deeaacaSLei Zhang 
2360deeaacaSLei Zhang   SmallVector<Value, 4> sourceIndices;
2370deeaacaSLei Zhang   if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
2387c38fd60SJacques Pienaar                                   storeOp.getIndices(), sourceIndices)))
2390deeaacaSLei Zhang     return failure();
2400deeaacaSLei Zhang 
2410deeaacaSLei Zhang   replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
2420deeaacaSLei Zhang   return success();
2430deeaacaSLei Zhang }
2440deeaacaSLei Zhang 
populateFoldSubViewOpPatterns(RewritePatternSet & patterns)2450deeaacaSLei Zhang void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
246f8a2cd67SUday Bondhugula   patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
247f8a2cd67SUday Bondhugula                LoadOpOfSubViewFolder<memref::LoadOp>,
2480deeaacaSLei Zhang                LoadOpOfSubViewFolder<vector::TransferReadOp>,
249f8a2cd67SUday Bondhugula                StoreOpOfSubViewFolder<AffineStoreOp>,
2500deeaacaSLei Zhang                StoreOpOfSubViewFolder<memref::StoreOp>,
2510deeaacaSLei Zhang                StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
2520deeaacaSLei Zhang       patterns.getContext());
2530deeaacaSLei Zhang }
2540deeaacaSLei Zhang 
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
2580deeaacaSLei Zhang 
2590deeaacaSLei Zhang namespace {
2600deeaacaSLei Zhang 
2610deeaacaSLei Zhang struct FoldSubViewOpsPass final
2620deeaacaSLei Zhang     : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
2630deeaacaSLei Zhang   void runOnOperation() override;
2640deeaacaSLei Zhang };
2650deeaacaSLei Zhang 
2660deeaacaSLei Zhang } // namespace
2670deeaacaSLei Zhang 
runOnOperation()2680deeaacaSLei Zhang void FoldSubViewOpsPass::runOnOperation() {
2690deeaacaSLei Zhang   RewritePatternSet patterns(&getContext());
2700deeaacaSLei Zhang   memref::populateFoldSubViewOpPatterns(patterns);
27111a7635bSRiver Riddle   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
2720deeaacaSLei Zhang }
2730deeaacaSLei Zhang 
createFoldSubViewOpsPass()2740deeaacaSLei Zhang std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
2750deeaacaSLei Zhang   return std::make_unique<FoldSubViewOpsPass>();
2760deeaacaSLei Zhang }
277