1 //===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass folds loading/storing from/to subview ops into
10 // loading/storing from/to the original memref.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Utility functions
25 //===----------------------------------------------------------------------===//
26 
27 /// Given the 'indices' of an load/store operation where the memref is a result
28 /// of a subview op, returns the indices w.r.t to the source memref of the
29 /// subview op. For example
30 ///
31 /// %0 = ... : memref<12x42xf32>
32 /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
33 ///          memref<4x4xf32, offset=?, strides=[?, ?]>
34 /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
35 ///
36 /// could be folded into
37 ///
38 /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
39 ///          memref<12x42xf32>
40 static LogicalResult
41 resolveSourceIndices(Location loc, PatternRewriter &rewriter,
42                      memref::SubViewOp subViewOp, ValueRange indices,
43                      SmallVectorImpl<Value> &sourceIndices) {
44   // TODO: Aborting when the offsets are static. There might be a way to fold
45   // the subview op with load even if the offsets have been canonicalized
46   // away.
47   SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
48   auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
49   auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
50   assert(opRanges.size() == indices.size() &&
51          "expected as many indices as rank of subview op result type");
52 
53   // New indices for the load are the current indices * subview_stride +
54   // subview_offset.
55   sourceIndices.resize(indices.size());
56   for (auto index : llvm::enumerate(indices)) {
57     auto offset = *(opOffsets.begin() + index.index());
58     auto stride = *(opStrides.begin() + index.index());
59     auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
60     sourceIndices[index.index()] =
61         rewriter.create<AddIOp>(loc, offset, mul).getResult();
62   }
63   return success();
64 }
65 
66 /// Helpers to access the memref operand for each op.
67 static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }
68 
69 static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
70 
71 static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); }
72 
73 static Value getMemRefOperand(vector::TransferWriteOp op) {
74   return op.source();
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Patterns
79 //===----------------------------------------------------------------------===//
80 
81 namespace {
82 /// Merges subview operation with load/transferRead operation.
83 template <typename OpTy>
84 class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
85 public:
86   using OpRewritePattern<OpTy>::OpRewritePattern;
87 
88   LogicalResult matchAndRewrite(OpTy loadOp,
89                                 PatternRewriter &rewriter) const override;
90 
91 private:
92   void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
93                  ArrayRef<Value> sourceIndices,
94                  PatternRewriter &rewriter) const;
95 };
96 
97 /// Merges subview operation with store/transferWriteOp operation.
98 template <typename OpTy>
99 class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
100 public:
101   using OpRewritePattern<OpTy>::OpRewritePattern;
102 
103   LogicalResult matchAndRewrite(OpTy storeOp,
104                                 PatternRewriter &rewriter) const override;
105 
106 private:
107   void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
108                  ArrayRef<Value> sourceIndices,
109                  PatternRewriter &rewriter) const;
110 };
111 
112 template <>
113 void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
114     memref::LoadOp loadOp, memref::SubViewOp subViewOp,
115     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
116   rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(),
117                                               sourceIndices);
118 }
119 
120 template <>
121 void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
122     vector::TransferReadOp loadOp, memref::SubViewOp subViewOp,
123     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
124   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
125       loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
126       loadOp.permutation_map(), loadOp.padding(), loadOp.in_boundsAttr());
127 }
128 
129 template <>
130 void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp(
131     memref::StoreOp storeOp, memref::SubViewOp subViewOp,
132     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
133   rewriter.replaceOpWithNewOp<memref::StoreOp>(
134       storeOp, storeOp.value(), subViewOp.source(), sourceIndices);
135 }
136 
137 template <>
138 void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
139     vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
140     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
141   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
142       transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
143       sourceIndices, transferWriteOp.permutation_map(),
144       transferWriteOp.in_boundsAttr());
145 }
146 } // namespace
147 
148 template <typename OpTy>
149 LogicalResult
150 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
151                                              PatternRewriter &rewriter) const {
152   auto subViewOp =
153       getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
154   if (!subViewOp)
155     return failure();
156 
157   SmallVector<Value, 4> sourceIndices;
158   if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
159                                   loadOp.indices(), sourceIndices)))
160     return failure();
161 
162   replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
163   return success();
164 }
165 
166 template <typename OpTy>
167 LogicalResult
168 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
169                                               PatternRewriter &rewriter) const {
170   auto subViewOp =
171       getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
172   if (!subViewOp)
173     return failure();
174 
175   SmallVector<Value, 4> sourceIndices;
176   if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
177                                   storeOp.indices(), sourceIndices)))
178     return failure();
179 
180   replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
181   return success();
182 }
183 
184 void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
185   patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
186                LoadOpOfSubViewFolder<vector::TransferReadOp>,
187                StoreOpOfSubViewFolder<memref::StoreOp>,
188                StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
189       patterns.getContext());
190 }
191 
192 //===----------------------------------------------------------------------===//
193 // Pass registration
194 //===----------------------------------------------------------------------===//
195 
196 namespace {
197 
198 #define GEN_PASS_CLASSES
199 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
200 
201 struct FoldSubViewOpsPass final
202     : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
203   void runOnOperation() override;
204 };
205 
206 } // namespace
207 
208 void FoldSubViewOpsPass::runOnOperation() {
209   RewritePatternSet patterns(&getContext());
210   memref::populateFoldSubViewOpPatterns(patterns);
211   (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
212                                      std::move(patterns));
213 }
214 
215 std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
216   return std::make_unique<FoldSubViewOpsPass>();
217 }
218