1 //===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass folds loading/storing from/to subview ops into
10 // loading/storing from/to the original memref.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/MemRef/IR/MemRef.h"
15 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/Dialect/Vector/VectorOps.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Utility functions
25 //===----------------------------------------------------------------------===//
26 
27 /// Given the 'indices' of an load/store operation where the memref is a result
28 /// of a subview op, returns the indices w.r.t to the source memref of the
29 /// subview op. For example
30 ///
31 /// %0 = ... : memref<12x42xf32>
32 /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
33 ///          memref<4x4xf32, offset=?, strides=[?, ?]>
34 /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
35 ///
36 /// could be folded into
37 ///
38 /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
39 ///          memref<12x42xf32>
40 static LogicalResult
41 resolveSourceIndices(Location loc, PatternRewriter &rewriter,
42                      memref::SubViewOp subViewOp, ValueRange indices,
43                      SmallVectorImpl<Value> &sourceIndices) {
44   // TODO: Aborting when the offsets are static. There might be a way to fold
45   // the subview op with load even if the offsets have been canonicalized
46   // away.
47   SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc);
48   if (opRanges.size() != indices.size()) {
49     // For the rank-reduced cases, we can only handle the folding when the
50     // offset is zero, size is 1 and stride is 1.
51     return failure();
52   }
53   auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; });
54   auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; });
55 
56   // New indices for the load are the current indices * subview_stride +
57   // subview_offset.
58   sourceIndices.resize(indices.size());
59   for (auto index : llvm::enumerate(indices)) {
60     auto offset = *(opOffsets.begin() + index.index());
61     auto stride = *(opStrides.begin() + index.index());
62     auto mul = rewriter.create<MulIOp>(loc, index.value(), stride);
63     sourceIndices[index.index()] =
64         rewriter.create<AddIOp>(loc, offset, mul).getResult();
65   }
66   return success();
67 }
68 
69 /// Helpers to access the memref operand for each op.
70 static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }
71 
72 static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
73 
74 static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); }
75 
76 static Value getMemRefOperand(vector::TransferWriteOp op) {
77   return op.source();
78 }
79 
80 //===----------------------------------------------------------------------===//
81 // Patterns
82 //===----------------------------------------------------------------------===//
83 
84 namespace {
85 /// Merges subview operation with load/transferRead operation.
86 template <typename OpTy>
87 class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
88 public:
89   using OpRewritePattern<OpTy>::OpRewritePattern;
90 
91   LogicalResult matchAndRewrite(OpTy loadOp,
92                                 PatternRewriter &rewriter) const override;
93 
94 private:
95   void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
96                  ArrayRef<Value> sourceIndices,
97                  PatternRewriter &rewriter) const;
98 };
99 
100 /// Merges subview operation with store/transferWriteOp operation.
101 template <typename OpTy>
102 class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
103 public:
104   using OpRewritePattern<OpTy>::OpRewritePattern;
105 
106   LogicalResult matchAndRewrite(OpTy storeOp,
107                                 PatternRewriter &rewriter) const override;
108 
109 private:
110   void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
111                  ArrayRef<Value> sourceIndices,
112                  PatternRewriter &rewriter) const;
113 };
114 
115 template <>
116 void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
117     memref::LoadOp loadOp, memref::SubViewOp subViewOp,
118     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
119   rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(),
120                                               sourceIndices);
121 }
122 
123 template <>
124 void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
125     vector::TransferReadOp loadOp, memref::SubViewOp subViewOp,
126     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
127   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
128       loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
129       loadOp.permutation_map(), loadOp.padding(), loadOp.in_boundsAttr());
130 }
131 
132 template <>
133 void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp(
134     memref::StoreOp storeOp, memref::SubViewOp subViewOp,
135     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
136   rewriter.replaceOpWithNewOp<memref::StoreOp>(
137       storeOp, storeOp.value(), subViewOp.source(), sourceIndices);
138 }
139 
140 template <>
141 void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
142     vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
143     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
144   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
145       transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
146       sourceIndices, transferWriteOp.permutation_map(),
147       transferWriteOp.in_boundsAttr());
148 }
149 } // namespace
150 
151 template <typename OpTy>
152 LogicalResult
153 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
154                                              PatternRewriter &rewriter) const {
155   auto subViewOp =
156       getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
157   if (!subViewOp)
158     return failure();
159 
160   SmallVector<Value, 4> sourceIndices;
161   if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
162                                   loadOp.indices(), sourceIndices)))
163     return failure();
164 
165   replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
166   return success();
167 }
168 
169 template <typename OpTy>
170 LogicalResult
171 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
172                                               PatternRewriter &rewriter) const {
173   auto subViewOp =
174       getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
175   if (!subViewOp)
176     return failure();
177 
178   SmallVector<Value, 4> sourceIndices;
179   if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
180                                   storeOp.indices(), sourceIndices)))
181     return failure();
182 
183   replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
184   return success();
185 }
186 
187 void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
188   patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
189                LoadOpOfSubViewFolder<vector::TransferReadOp>,
190                StoreOpOfSubViewFolder<memref::StoreOp>,
191                StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
192       patterns.getContext());
193 }
194 
195 //===----------------------------------------------------------------------===//
196 // Pass registration
197 //===----------------------------------------------------------------------===//
198 
199 namespace {
200 
201 #define GEN_PASS_CLASSES
202 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
203 
204 struct FoldSubViewOpsPass final
205     : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
206   void runOnOperation() override;
207 };
208 
209 } // namespace
210 
211 void FoldSubViewOpsPass::runOnOperation() {
212   RewritePatternSet patterns(&getContext());
213   memref::populateFoldSubViewOpPatterns(patterns);
214   (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
215                                      std::move(patterns));
216 }
217 
218 std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
219   return std::make_unique<FoldSubViewOpsPass>();
220 }
221