//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//
130deeaacaSLei Zhang 
14fd15e2b8SMaheshRavishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
15*a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
160deeaacaSLei Zhang #include "mlir/Dialect/MemRef/IR/MemRef.h"
170deeaacaSLei Zhang #include "mlir/Dialect/MemRef/Transforms/Passes.h"
180deeaacaSLei Zhang #include "mlir/Dialect/StandardOps/IR/Ops.h"
190deeaacaSLei Zhang #include "mlir/Dialect/Vector/VectorOps.h"
200deeaacaSLei Zhang #include "mlir/IR/BuiltinTypes.h"
210deeaacaSLei Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
220deeaacaSLei Zhang 
230deeaacaSLei Zhang using namespace mlir;
240deeaacaSLei Zhang 
//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//
280deeaacaSLei Zhang 
290deeaacaSLei Zhang /// Given the 'indices' of an load/store operation where the memref is a result
300deeaacaSLei Zhang /// of a subview op, returns the indices w.r.t to the source memref of the
310deeaacaSLei Zhang /// subview op. For example
320deeaacaSLei Zhang ///
330deeaacaSLei Zhang /// %0 = ... : memref<12x42xf32>
340deeaacaSLei Zhang /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
350deeaacaSLei Zhang ///          memref<4x4xf32, offset=?, strides=[?, ?]>
360deeaacaSLei Zhang /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
370deeaacaSLei Zhang ///
380deeaacaSLei Zhang /// could be folded into
390deeaacaSLei Zhang ///
400deeaacaSLei Zhang /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
410deeaacaSLei Zhang ///          memref<12x42xf32>
420deeaacaSLei Zhang static LogicalResult
430deeaacaSLei Zhang resolveSourceIndices(Location loc, PatternRewriter &rewriter,
440deeaacaSLei Zhang                      memref::SubViewOp subViewOp, ValueRange indices,
450deeaacaSLei Zhang                      SmallVectorImpl<Value> &sourceIndices) {
46fd15e2b8SMaheshRavishankar   SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
47fd15e2b8SMaheshRavishankar   SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
48fd15e2b8SMaheshRavishankar   SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
490deeaacaSLei Zhang 
50fd15e2b8SMaheshRavishankar   SmallVector<Value> useIndices;
51fd15e2b8SMaheshRavishankar   // Check if this is rank-reducing case. Then for every unit-dim size add a
52fd15e2b8SMaheshRavishankar   // zero to the indices.
53fd15e2b8SMaheshRavishankar   unsigned resultDim = 0;
544cf9bf6cSMaheshRavishankar   llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
554cf9bf6cSMaheshRavishankar   for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
564cf9bf6cSMaheshRavishankar     if (unusedDims.count(dim))
57*a54f4eaeSMogball       useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
584cf9bf6cSMaheshRavishankar     else
59fd15e2b8SMaheshRavishankar       useIndices.push_back(indices[resultDim++]);
60fd15e2b8SMaheshRavishankar   }
61fd15e2b8SMaheshRavishankar   if (useIndices.size() != mixedOffsets.size())
62fd15e2b8SMaheshRavishankar     return failure();
63fd15e2b8SMaheshRavishankar   sourceIndices.resize(useIndices.size());
64fd15e2b8SMaheshRavishankar   for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
65fd15e2b8SMaheshRavishankar     SmallVector<Value> dynamicOperands;
66fd15e2b8SMaheshRavishankar     AffineExpr expr = rewriter.getAffineDimExpr(0);
67fd15e2b8SMaheshRavishankar     unsigned numSymbols = 0;
68fd15e2b8SMaheshRavishankar     dynamicOperands.push_back(useIndices[index]);
69fd15e2b8SMaheshRavishankar 
70fd15e2b8SMaheshRavishankar     // Multiply the stride;
71fd15e2b8SMaheshRavishankar     if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
72fd15e2b8SMaheshRavishankar       expr = expr * attr.cast<IntegerAttr>().getInt();
73fd15e2b8SMaheshRavishankar     } else {
74fd15e2b8SMaheshRavishankar       dynamicOperands.push_back(mixedStrides[index].get<Value>());
75fd15e2b8SMaheshRavishankar       expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
76fd15e2b8SMaheshRavishankar     }
77fd15e2b8SMaheshRavishankar 
78fd15e2b8SMaheshRavishankar     // Add the offset.
79fd15e2b8SMaheshRavishankar     if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
80fd15e2b8SMaheshRavishankar       expr = expr + attr.cast<IntegerAttr>().getInt();
81fd15e2b8SMaheshRavishankar     } else {
82fd15e2b8SMaheshRavishankar       dynamicOperands.push_back(mixedOffsets[index].get<Value>());
83fd15e2b8SMaheshRavishankar       expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
84fd15e2b8SMaheshRavishankar     }
85fd15e2b8SMaheshRavishankar     Location loc = subViewOp.getLoc();
86fd15e2b8SMaheshRavishankar     sourceIndices[index] = rewriter.create<AffineApplyOp>(
87fd15e2b8SMaheshRavishankar         loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
880deeaacaSLei Zhang   }
890deeaacaSLei Zhang   return success();
900deeaacaSLei Zhang }
910deeaacaSLei Zhang 
920deeaacaSLei Zhang /// Helpers to access the memref operand for each op.
930deeaacaSLei Zhang static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }
940deeaacaSLei Zhang 
950deeaacaSLei Zhang static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
960deeaacaSLei Zhang 
970deeaacaSLei Zhang static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); }
980deeaacaSLei Zhang 
990deeaacaSLei Zhang static Value getMemRefOperand(vector::TransferWriteOp op) {
1000deeaacaSLei Zhang   return op.source();
1010deeaacaSLei Zhang }
1020deeaacaSLei Zhang 
1034cf9bf6cSMaheshRavishankar /// Given the permutation map of the original
1044cf9bf6cSMaheshRavishankar /// `vector.transfer_read`/`vector.transfer_write` operations compute the
1054cf9bf6cSMaheshRavishankar /// permutation map to use after the subview is folded with it.
1064cf9bf6cSMaheshRavishankar static AffineMap getPermutationMap(MLIRContext *context,
1074cf9bf6cSMaheshRavishankar                                    memref::SubViewOp subViewOp,
1084cf9bf6cSMaheshRavishankar                                    AffineMap currPermutationMap) {
1094cf9bf6cSMaheshRavishankar   llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
1104cf9bf6cSMaheshRavishankar   SmallVector<AffineExpr> exprs;
1114cf9bf6cSMaheshRavishankar   int64_t sourceRank = subViewOp.getSourceType().getRank();
1124cf9bf6cSMaheshRavishankar   for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
1134cf9bf6cSMaheshRavishankar     if (unusedDims.count(dim))
1144cf9bf6cSMaheshRavishankar       continue;
115b12e4c17Sthomasraoux     exprs.push_back(getAffineDimExpr(dim, context));
1164cf9bf6cSMaheshRavishankar   }
1174cf9bf6cSMaheshRavishankar   auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
1184cf9bf6cSMaheshRavishankar   return currPermutationMap.compose(resultDimToSourceDimMap);
1194cf9bf6cSMaheshRavishankar }
1204cf9bf6cSMaheshRavishankar 
//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//
1240deeaacaSLei Zhang 
1250deeaacaSLei Zhang namespace {
1260deeaacaSLei Zhang /// Merges subview operation with load/transferRead operation.
1270deeaacaSLei Zhang template <typename OpTy>
1280deeaacaSLei Zhang class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
1290deeaacaSLei Zhang public:
1300deeaacaSLei Zhang   using OpRewritePattern<OpTy>::OpRewritePattern;
1310deeaacaSLei Zhang 
1320deeaacaSLei Zhang   LogicalResult matchAndRewrite(OpTy loadOp,
1330deeaacaSLei Zhang                                 PatternRewriter &rewriter) const override;
1340deeaacaSLei Zhang 
1350deeaacaSLei Zhang private:
1360deeaacaSLei Zhang   void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
1370deeaacaSLei Zhang                  ArrayRef<Value> sourceIndices,
1380deeaacaSLei Zhang                  PatternRewriter &rewriter) const;
1390deeaacaSLei Zhang };
1400deeaacaSLei Zhang 
1410deeaacaSLei Zhang /// Merges subview operation with store/transferWriteOp operation.
1420deeaacaSLei Zhang template <typename OpTy>
1430deeaacaSLei Zhang class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
1440deeaacaSLei Zhang public:
1450deeaacaSLei Zhang   using OpRewritePattern<OpTy>::OpRewritePattern;
1460deeaacaSLei Zhang 
1470deeaacaSLei Zhang   LogicalResult matchAndRewrite(OpTy storeOp,
1480deeaacaSLei Zhang                                 PatternRewriter &rewriter) const override;
1490deeaacaSLei Zhang 
1500deeaacaSLei Zhang private:
1510deeaacaSLei Zhang   void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
1520deeaacaSLei Zhang                  ArrayRef<Value> sourceIndices,
1530deeaacaSLei Zhang                  PatternRewriter &rewriter) const;
1540deeaacaSLei Zhang };
1550deeaacaSLei Zhang 
1560deeaacaSLei Zhang template <>
1570deeaacaSLei Zhang void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
1580deeaacaSLei Zhang     memref::LoadOp loadOp, memref::SubViewOp subViewOp,
1590deeaacaSLei Zhang     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
1600deeaacaSLei Zhang   rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(),
1610deeaacaSLei Zhang                                               sourceIndices);
1620deeaacaSLei Zhang }
1630deeaacaSLei Zhang 
1640deeaacaSLei Zhang template <>
1650deeaacaSLei Zhang void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
1660deeaacaSLei Zhang     vector::TransferReadOp loadOp, memref::SubViewOp subViewOp,
1670deeaacaSLei Zhang     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
1680deeaacaSLei Zhang   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
1690deeaacaSLei Zhang       loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
1704cf9bf6cSMaheshRavishankar       getPermutationMap(rewriter.getContext(), subViewOp,
1714cf9bf6cSMaheshRavishankar                         loadOp.permutation_map()),
1724cf9bf6cSMaheshRavishankar       loadOp.padding(), loadOp.in_boundsAttr());
1730deeaacaSLei Zhang }
1740deeaacaSLei Zhang 
1750deeaacaSLei Zhang template <>
1760deeaacaSLei Zhang void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp(
1770deeaacaSLei Zhang     memref::StoreOp storeOp, memref::SubViewOp subViewOp,
1780deeaacaSLei Zhang     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
1790deeaacaSLei Zhang   rewriter.replaceOpWithNewOp<memref::StoreOp>(
1800deeaacaSLei Zhang       storeOp, storeOp.value(), subViewOp.source(), sourceIndices);
1810deeaacaSLei Zhang }
1820deeaacaSLei Zhang 
1830deeaacaSLei Zhang template <>
1840deeaacaSLei Zhang void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
1850deeaacaSLei Zhang     vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
1860deeaacaSLei Zhang     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
1870deeaacaSLei Zhang   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1880deeaacaSLei Zhang       transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
1894cf9bf6cSMaheshRavishankar       sourceIndices,
1904cf9bf6cSMaheshRavishankar       getPermutationMap(rewriter.getContext(), subViewOp,
1914cf9bf6cSMaheshRavishankar                         transferWriteOp.permutation_map()),
1920deeaacaSLei Zhang       transferWriteOp.in_boundsAttr());
1930deeaacaSLei Zhang }
1940deeaacaSLei Zhang } // namespace
1950deeaacaSLei Zhang 
1960deeaacaSLei Zhang template <typename OpTy>
1970deeaacaSLei Zhang LogicalResult
1980deeaacaSLei Zhang LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
1990deeaacaSLei Zhang                                              PatternRewriter &rewriter) const {
2000deeaacaSLei Zhang   auto subViewOp =
2010deeaacaSLei Zhang       getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
2020deeaacaSLei Zhang   if (!subViewOp)
2030deeaacaSLei Zhang     return failure();
2040deeaacaSLei Zhang 
2050deeaacaSLei Zhang   SmallVector<Value, 4> sourceIndices;
2060deeaacaSLei Zhang   if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
2070deeaacaSLei Zhang                                   loadOp.indices(), sourceIndices)))
2080deeaacaSLei Zhang     return failure();
2090deeaacaSLei Zhang 
2100deeaacaSLei Zhang   replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
2110deeaacaSLei Zhang   return success();
2120deeaacaSLei Zhang }
2130deeaacaSLei Zhang 
2140deeaacaSLei Zhang template <typename OpTy>
2150deeaacaSLei Zhang LogicalResult
2160deeaacaSLei Zhang StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
2170deeaacaSLei Zhang                                               PatternRewriter &rewriter) const {
2180deeaacaSLei Zhang   auto subViewOp =
2190deeaacaSLei Zhang       getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
2200deeaacaSLei Zhang   if (!subViewOp)
2210deeaacaSLei Zhang     return failure();
2220deeaacaSLei Zhang 
2230deeaacaSLei Zhang   SmallVector<Value, 4> sourceIndices;
2240deeaacaSLei Zhang   if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
2250deeaacaSLei Zhang                                   storeOp.indices(), sourceIndices)))
2260deeaacaSLei Zhang     return failure();
2270deeaacaSLei Zhang 
2280deeaacaSLei Zhang   replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
2290deeaacaSLei Zhang   return success();
2300deeaacaSLei Zhang }
2310deeaacaSLei Zhang 
2320deeaacaSLei Zhang void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
2330deeaacaSLei Zhang   patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
2340deeaacaSLei Zhang                LoadOpOfSubViewFolder<vector::TransferReadOp>,
2350deeaacaSLei Zhang                StoreOpOfSubViewFolder<memref::StoreOp>,
2360deeaacaSLei Zhang                StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
2370deeaacaSLei Zhang       patterns.getContext());
2380deeaacaSLei Zhang }
2390deeaacaSLei Zhang 
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
2430deeaacaSLei Zhang 
2440deeaacaSLei Zhang namespace {
2450deeaacaSLei Zhang 
2460deeaacaSLei Zhang #define GEN_PASS_CLASSES
2470deeaacaSLei Zhang #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
2480deeaacaSLei Zhang 
2490deeaacaSLei Zhang struct FoldSubViewOpsPass final
2500deeaacaSLei Zhang     : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
2510deeaacaSLei Zhang   void runOnOperation() override;
2520deeaacaSLei Zhang };
2530deeaacaSLei Zhang 
2540deeaacaSLei Zhang } // namespace
2550deeaacaSLei Zhang 
2560deeaacaSLei Zhang void FoldSubViewOpsPass::runOnOperation() {
2570deeaacaSLei Zhang   RewritePatternSet patterns(&getContext());
2580deeaacaSLei Zhang   memref::populateFoldSubViewOpPatterns(patterns);
2590deeaacaSLei Zhang   (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
2600deeaacaSLei Zhang                                      std::move(patterns));
2610deeaacaSLei Zhang }
2620deeaacaSLei Zhang 
2630deeaacaSLei Zhang std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
2640deeaacaSLei Zhang   return std::make_unique<FoldSubViewOpsPass>();
2650deeaacaSLei Zhang }
266