10deeaacaSLei Zhang //===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
20deeaacaSLei Zhang //
30deeaacaSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
40deeaacaSLei Zhang // See https://llvm.org/LICENSE.txt for license information.
50deeaacaSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
60deeaacaSLei Zhang //
70deeaacaSLei Zhang //===----------------------------------------------------------------------===//
80deeaacaSLei Zhang //
90deeaacaSLei Zhang // This transformation pass folds loading/storing from/to subview ops into
100deeaacaSLei Zhang // loading/storing from/to the original memref.
110deeaacaSLei Zhang //
120deeaacaSLei Zhang //===----------------------------------------------------------------------===//
130deeaacaSLei Zhang 
14fd15e2b8SMaheshRavishankar #include "mlir/Dialect/Affine/IR/AffineOps.h"
15a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
160deeaacaSLei Zhang #include "mlir/Dialect/MemRef/IR/MemRef.h"
170deeaacaSLei Zhang #include "mlir/Dialect/MemRef/Transforms/Passes.h"
180deeaacaSLei Zhang #include "mlir/Dialect/StandardOps/IR/Ops.h"
190deeaacaSLei Zhang #include "mlir/Dialect/Vector/VectorOps.h"
200deeaacaSLei Zhang #include "mlir/IR/BuiltinTypes.h"
210deeaacaSLei Zhang #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
220deeaacaSLei Zhang 
230deeaacaSLei Zhang using namespace mlir;
240deeaacaSLei Zhang 
250deeaacaSLei Zhang //===----------------------------------------------------------------------===//
260deeaacaSLei Zhang // Utility functions
270deeaacaSLei Zhang //===----------------------------------------------------------------------===//
280deeaacaSLei Zhang 
290deeaacaSLei Zhang /// Given the 'indices' of an load/store operation where the memref is a result
300deeaacaSLei Zhang /// of a subview op, returns the indices w.r.t to the source memref of the
310deeaacaSLei Zhang /// subview op. For example
320deeaacaSLei Zhang ///
330deeaacaSLei Zhang /// %0 = ... : memref<12x42xf32>
340deeaacaSLei Zhang /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
350deeaacaSLei Zhang ///          memref<4x4xf32, offset=?, strides=[?, ?]>
360deeaacaSLei Zhang /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
370deeaacaSLei Zhang ///
380deeaacaSLei Zhang /// could be folded into
390deeaacaSLei Zhang ///
400deeaacaSLei Zhang /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
410deeaacaSLei Zhang ///          memref<12x42xf32>
420deeaacaSLei Zhang static LogicalResult
430deeaacaSLei Zhang resolveSourceIndices(Location loc, PatternRewriter &rewriter,
440deeaacaSLei Zhang                      memref::SubViewOp subViewOp, ValueRange indices,
450deeaacaSLei Zhang                      SmallVectorImpl<Value> &sourceIndices) {
46fd15e2b8SMaheshRavishankar   SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
47fd15e2b8SMaheshRavishankar   SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
48fd15e2b8SMaheshRavishankar   SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
490deeaacaSLei Zhang 
50fd15e2b8SMaheshRavishankar   SmallVector<Value> useIndices;
51fd15e2b8SMaheshRavishankar   // Check if this is rank-reducing case. Then for every unit-dim size add a
52fd15e2b8SMaheshRavishankar   // zero to the indices.
53fd15e2b8SMaheshRavishankar   unsigned resultDim = 0;
544cf9bf6cSMaheshRavishankar   llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
554cf9bf6cSMaheshRavishankar   for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
564cf9bf6cSMaheshRavishankar     if (unusedDims.count(dim))
57a54f4eaeSMogball       useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
584cf9bf6cSMaheshRavishankar     else
59fd15e2b8SMaheshRavishankar       useIndices.push_back(indices[resultDim++]);
60fd15e2b8SMaheshRavishankar   }
61fd15e2b8SMaheshRavishankar   if (useIndices.size() != mixedOffsets.size())
62fd15e2b8SMaheshRavishankar     return failure();
63fd15e2b8SMaheshRavishankar   sourceIndices.resize(useIndices.size());
64fd15e2b8SMaheshRavishankar   for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
65fd15e2b8SMaheshRavishankar     SmallVector<Value> dynamicOperands;
66fd15e2b8SMaheshRavishankar     AffineExpr expr = rewriter.getAffineDimExpr(0);
67fd15e2b8SMaheshRavishankar     unsigned numSymbols = 0;
68fd15e2b8SMaheshRavishankar     dynamicOperands.push_back(useIndices[index]);
69fd15e2b8SMaheshRavishankar 
70fd15e2b8SMaheshRavishankar     // Multiply the stride;
71fd15e2b8SMaheshRavishankar     if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
72fd15e2b8SMaheshRavishankar       expr = expr * attr.cast<IntegerAttr>().getInt();
73fd15e2b8SMaheshRavishankar     } else {
74fd15e2b8SMaheshRavishankar       dynamicOperands.push_back(mixedStrides[index].get<Value>());
75fd15e2b8SMaheshRavishankar       expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
76fd15e2b8SMaheshRavishankar     }
77fd15e2b8SMaheshRavishankar 
78fd15e2b8SMaheshRavishankar     // Add the offset.
79fd15e2b8SMaheshRavishankar     if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
80fd15e2b8SMaheshRavishankar       expr = expr + attr.cast<IntegerAttr>().getInt();
81fd15e2b8SMaheshRavishankar     } else {
82fd15e2b8SMaheshRavishankar       dynamicOperands.push_back(mixedOffsets[index].get<Value>());
83fd15e2b8SMaheshRavishankar       expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
84fd15e2b8SMaheshRavishankar     }
85fd15e2b8SMaheshRavishankar     Location loc = subViewOp.getLoc();
86fd15e2b8SMaheshRavishankar     sourceIndices[index] = rewriter.create<AffineApplyOp>(
87fd15e2b8SMaheshRavishankar         loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
880deeaacaSLei Zhang   }
890deeaacaSLei Zhang   return success();
900deeaacaSLei Zhang }
910deeaacaSLei Zhang 
920deeaacaSLei Zhang /// Helpers to access the memref operand for each op.
93*f8a2cd67SUday Bondhugula template <typename LoadOrStoreOpTy>
94*f8a2cd67SUday Bondhugula static Value getMemRefOperand(LoadOrStoreOpTy op) {
95*f8a2cd67SUday Bondhugula   return op.memref();
96*f8a2cd67SUday Bondhugula }
970deeaacaSLei Zhang 
980deeaacaSLei Zhang static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
990deeaacaSLei Zhang 
1000deeaacaSLei Zhang static Value getMemRefOperand(vector::TransferWriteOp op) {
1010deeaacaSLei Zhang   return op.source();
1020deeaacaSLei Zhang }
1030deeaacaSLei Zhang 
1044cf9bf6cSMaheshRavishankar /// Given the permutation map of the original
1054cf9bf6cSMaheshRavishankar /// `vector.transfer_read`/`vector.transfer_write` operations compute the
1064cf9bf6cSMaheshRavishankar /// permutation map to use after the subview is folded with it.
107c537a943SNicolas Vasilache static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
1084cf9bf6cSMaheshRavishankar                                            memref::SubViewOp subViewOp,
1094cf9bf6cSMaheshRavishankar                                            AffineMap currPermutationMap) {
1104cf9bf6cSMaheshRavishankar   llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
1114cf9bf6cSMaheshRavishankar   SmallVector<AffineExpr> exprs;
1124cf9bf6cSMaheshRavishankar   int64_t sourceRank = subViewOp.getSourceType().getRank();
1134cf9bf6cSMaheshRavishankar   for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
1144cf9bf6cSMaheshRavishankar     if (unusedDims.count(dim))
1154cf9bf6cSMaheshRavishankar       continue;
116b12e4c17Sthomasraoux     exprs.push_back(getAffineDimExpr(dim, context));
1174cf9bf6cSMaheshRavishankar   }
1184cf9bf6cSMaheshRavishankar   auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
119c537a943SNicolas Vasilache   return AffineMapAttr::get(
120c537a943SNicolas Vasilache       currPermutationMap.compose(resultDimToSourceDimMap));
1214cf9bf6cSMaheshRavishankar }
1224cf9bf6cSMaheshRavishankar 
1230deeaacaSLei Zhang //===----------------------------------------------------------------------===//
1240deeaacaSLei Zhang // Patterns
1250deeaacaSLei Zhang //===----------------------------------------------------------------------===//
1260deeaacaSLei Zhang 
1270deeaacaSLei Zhang namespace {
1280deeaacaSLei Zhang /// Merges subview operation with load/transferRead operation.
1290deeaacaSLei Zhang template <typename OpTy>
1300deeaacaSLei Zhang class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
1310deeaacaSLei Zhang public:
1320deeaacaSLei Zhang   using OpRewritePattern<OpTy>::OpRewritePattern;
1330deeaacaSLei Zhang 
1340deeaacaSLei Zhang   LogicalResult matchAndRewrite(OpTy loadOp,
1350deeaacaSLei Zhang                                 PatternRewriter &rewriter) const override;
1360deeaacaSLei Zhang 
1370deeaacaSLei Zhang private:
1380deeaacaSLei Zhang   void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
1390deeaacaSLei Zhang                  ArrayRef<Value> sourceIndices,
1400deeaacaSLei Zhang                  PatternRewriter &rewriter) const;
1410deeaacaSLei Zhang };
1420deeaacaSLei Zhang 
1430deeaacaSLei Zhang /// Merges subview operation with store/transferWriteOp operation.
1440deeaacaSLei Zhang template <typename OpTy>
1450deeaacaSLei Zhang class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
1460deeaacaSLei Zhang public:
1470deeaacaSLei Zhang   using OpRewritePattern<OpTy>::OpRewritePattern;
1480deeaacaSLei Zhang 
1490deeaacaSLei Zhang   LogicalResult matchAndRewrite(OpTy storeOp,
1500deeaacaSLei Zhang                                 PatternRewriter &rewriter) const override;
1510deeaacaSLei Zhang 
1520deeaacaSLei Zhang private:
1530deeaacaSLei Zhang   void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
1540deeaacaSLei Zhang                  ArrayRef<Value> sourceIndices,
1550deeaacaSLei Zhang                  PatternRewriter &rewriter) const;
1560deeaacaSLei Zhang };
1570deeaacaSLei Zhang 
158*f8a2cd67SUday Bondhugula template <typename LoadOpTy>
159*f8a2cd67SUday Bondhugula void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
160*f8a2cd67SUday Bondhugula     LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
161*f8a2cd67SUday Bondhugula     PatternRewriter &rewriter) const {
162*f8a2cd67SUday Bondhugula   rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.source(),
1630deeaacaSLei Zhang                                         sourceIndices);
1640deeaacaSLei Zhang }
1650deeaacaSLei Zhang 
1660deeaacaSLei Zhang template <>
1670deeaacaSLei Zhang void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
168c537a943SNicolas Vasilache     vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
1690deeaacaSLei Zhang     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
170c537a943SNicolas Vasilache   // TODO: support 0-d corner case.
171c537a943SNicolas Vasilache   if (transferReadOp.getTransferRank() == 0)
172c537a943SNicolas Vasilache     return;
1730deeaacaSLei Zhang   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
174c537a943SNicolas Vasilache       transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
175c537a943SNicolas Vasilache       sourceIndices,
176c537a943SNicolas Vasilache       getPermutationMapAttr(rewriter.getContext(), subViewOp,
177c537a943SNicolas Vasilache                             transferReadOp.permutation_map()),
178c537a943SNicolas Vasilache       transferReadOp.padding(),
179c537a943SNicolas Vasilache       /*mask=*/Value(), transferReadOp.in_boundsAttr());
1800deeaacaSLei Zhang }
1810deeaacaSLei Zhang 
182*f8a2cd67SUday Bondhugula template <typename StoreOpTy>
183*f8a2cd67SUday Bondhugula void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
184*f8a2cd67SUday Bondhugula     StoreOpTy storeOp, memref::SubViewOp subViewOp,
1850deeaacaSLei Zhang     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
186*f8a2cd67SUday Bondhugula   rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.value(),
187*f8a2cd67SUday Bondhugula                                          subViewOp.source(), sourceIndices);
1880deeaacaSLei Zhang }
1890deeaacaSLei Zhang 
1900deeaacaSLei Zhang template <>
1910deeaacaSLei Zhang void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
1920deeaacaSLei Zhang     vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
1930deeaacaSLei Zhang     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
194c537a943SNicolas Vasilache   // TODO: support 0-d corner case.
195c537a943SNicolas Vasilache   if (transferWriteOp.getTransferRank() == 0)
196c537a943SNicolas Vasilache     return;
1970deeaacaSLei Zhang   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
1980deeaacaSLei Zhang       transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
1994cf9bf6cSMaheshRavishankar       sourceIndices,
200c537a943SNicolas Vasilache       getPermutationMapAttr(rewriter.getContext(), subViewOp,
2014cf9bf6cSMaheshRavishankar                             transferWriteOp.permutation_map()),
2020deeaacaSLei Zhang       transferWriteOp.in_boundsAttr());
2030deeaacaSLei Zhang }
2040deeaacaSLei Zhang } // namespace
2050deeaacaSLei Zhang 
2060deeaacaSLei Zhang template <typename OpTy>
2070deeaacaSLei Zhang LogicalResult
2080deeaacaSLei Zhang LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
2090deeaacaSLei Zhang                                              PatternRewriter &rewriter) const {
2100deeaacaSLei Zhang   auto subViewOp =
2110deeaacaSLei Zhang       getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
2120deeaacaSLei Zhang   if (!subViewOp)
2130deeaacaSLei Zhang     return failure();
2140deeaacaSLei Zhang 
2150deeaacaSLei Zhang   SmallVector<Value, 4> sourceIndices;
2160deeaacaSLei Zhang   if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
2170deeaacaSLei Zhang                                   loadOp.indices(), sourceIndices)))
2180deeaacaSLei Zhang     return failure();
2190deeaacaSLei Zhang 
2200deeaacaSLei Zhang   replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
2210deeaacaSLei Zhang   return success();
2220deeaacaSLei Zhang }
2230deeaacaSLei Zhang 
2240deeaacaSLei Zhang template <typename OpTy>
2250deeaacaSLei Zhang LogicalResult
2260deeaacaSLei Zhang StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
2270deeaacaSLei Zhang                                               PatternRewriter &rewriter) const {
2280deeaacaSLei Zhang   auto subViewOp =
2290deeaacaSLei Zhang       getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
2300deeaacaSLei Zhang   if (!subViewOp)
2310deeaacaSLei Zhang     return failure();
2320deeaacaSLei Zhang 
2330deeaacaSLei Zhang   SmallVector<Value, 4> sourceIndices;
2340deeaacaSLei Zhang   if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
2350deeaacaSLei Zhang                                   storeOp.indices(), sourceIndices)))
2360deeaacaSLei Zhang     return failure();
2370deeaacaSLei Zhang 
2380deeaacaSLei Zhang   replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
2390deeaacaSLei Zhang   return success();
2400deeaacaSLei Zhang }
2410deeaacaSLei Zhang 
2420deeaacaSLei Zhang void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
243*f8a2cd67SUday Bondhugula   patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
244*f8a2cd67SUday Bondhugula                LoadOpOfSubViewFolder<memref::LoadOp>,
2450deeaacaSLei Zhang                LoadOpOfSubViewFolder<vector::TransferReadOp>,
246*f8a2cd67SUday Bondhugula                StoreOpOfSubViewFolder<AffineStoreOp>,
2470deeaacaSLei Zhang                StoreOpOfSubViewFolder<memref::StoreOp>,
2480deeaacaSLei Zhang                StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
2490deeaacaSLei Zhang       patterns.getContext());
2500deeaacaSLei Zhang }
2510deeaacaSLei Zhang 
2520deeaacaSLei Zhang //===----------------------------------------------------------------------===//
2530deeaacaSLei Zhang // Pass registration
2540deeaacaSLei Zhang //===----------------------------------------------------------------------===//
2550deeaacaSLei Zhang 
2560deeaacaSLei Zhang namespace {
2570deeaacaSLei Zhang 
2580deeaacaSLei Zhang #define GEN_PASS_CLASSES
2590deeaacaSLei Zhang #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
2600deeaacaSLei Zhang 
2610deeaacaSLei Zhang struct FoldSubViewOpsPass final
2620deeaacaSLei Zhang     : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
2630deeaacaSLei Zhang   void runOnOperation() override;
2640deeaacaSLei Zhang };
2650deeaacaSLei Zhang 
2660deeaacaSLei Zhang } // namespace
2670deeaacaSLei Zhang 
2680deeaacaSLei Zhang void FoldSubViewOpsPass::runOnOperation() {
2690deeaacaSLei Zhang   RewritePatternSet patterns(&getContext());
2700deeaacaSLei Zhang   memref::populateFoldSubViewOpPatterns(patterns);
2710deeaacaSLei Zhang   (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
2720deeaacaSLei Zhang                                      std::move(patterns));
2730deeaacaSLei Zhang }
2740deeaacaSLei Zhang 
2750deeaacaSLei Zhang std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
2760deeaacaSLei Zhang   return std::make_unique<FoldSubViewOpsPass>();
2770deeaacaSLei Zhang }
278