//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the 'indices' of a load/store operation where the memref is a result
/// of a subview op, returns the indices w.r.t. the source memref of the
/// subview op. For example:
///
/// %0 = ... : memref<12x42xf32>
/// %1 = subview %0[%arg0, %arg1][4, 4][%stride1, %stride2] :
///          memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]>
/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
///
/// could be folded into
///
/// %2 = load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] :
///          memref<12x42xf32>
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
                     memref::SubViewOp subViewOp, ValueRange indices,
                     SmallVectorImpl<Value> &sourceIndices) {
  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();

  SmallVector<Value> useIndices;
  // Check if this is a rank-reducing case: for every unit-size dimension that
  // was dropped from the result, insert a zero index, since the load/store
  // supplies no index for it.
  ArrayRef<int64_t> resultShape = subViewOp.getType().getShape();
  unsigned resultDim = 0;
  for (OpFoldResult size : mixedSizes) {
    auto attr = size.dyn_cast<Attribute>();
    // Check if this dimension has been dropped, i.e. the size is a static 1
    // but the corresponding result dimension is either absent or not 1.
    if (attr && attr.cast<IntegerAttr>().getInt() == 1 &&
        (resultDim >= resultShape.size() || resultShape[resultDim] != 1))
      useIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
    else if (resultDim < resultShape.size())
      useIndices.push_back(indices[resultDim++]);
  }
  if (useIndices.size() != mixedOffsets.size())
    return failure();
  sourceIndices.resize(useIndices.size());
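  // For each dimension, compose the source index as
  // `useIndices[index] * stride + offset`. Statically-known strides and
  // offsets are folded directly into the affine expression; dynamic ones are
  // bound to symbols of the affine map.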
  for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
    SmallVector<Value> dynamicOperands;
    AffineExpr expr = rewriter.getAffineDimExpr(0);
    unsigned numSymbols = 0;
    dynamicOperands.push_back(useIndices[index]);

    // Multiply by the stride.
    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
      expr = expr * attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedStrides[index].get<Value>());
      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
    }

    // Add the offset.
    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
      expr = expr + attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
    }
    sourceIndices[index] = rewriter.create<AffineApplyOp>(
        loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
  }
  return success();
}

/// Helpers to access the memref operand for each op.
static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }

static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }

static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.source();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges a subview operation with a load/transfer_read operation.
template <typename OpTy>
class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;

private:
  void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};

/// Merges a subview operation with a store/transfer_write operation.
template <typename OpTy>
class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;

private:
  void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};
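
/// memref.load is rebuilt directly on the subview's source memref, using the
/// resolved source indices.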
template <>
void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
    memref::LoadOp loadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(),
                                              sourceIndices);
}
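
/// vector.transfer_read is rebuilt on the subview's source memref, preserving
/// the permutation map, padding value, and in_bounds attribute.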
template <>
void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
    vector::TransferReadOp loadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
      loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
      loadOp.permutation_map(), loadOp.padding(), loadOp.in_boundsAttr());
}
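
/// memref.store is rebuilt directly on the subview's source memref, using the
/// resolved source indices.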
template <>
void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp(
    memref::StoreOp storeOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<memref::StoreOp>(
      storeOp, storeOp.value(), subViewOp.source(), sourceIndices);
}
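
/// vector.transfer_write is rebuilt on the subview's source memref, preserving
/// the permutation map and in_bounds attribute.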
template <>
void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
    vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
      sourceIndices, transferWriteOp.permutation_map(),
      transferWriteOp.in_boundsAttr());
}
} // namespace
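
/// Folds a load-like op whose memref operand is produced by a subview: the
/// access indices are rewritten w.r.t. the subview's source memref and the op
/// is recreated on that source.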
template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
                                             PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
                                  loadOp.indices(), sourceIndices)))
    return failure();

  replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
  return success();
}
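
/// Folds a store-like op whose memref operand is produced by a subview,
/// mirroring the load case above.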
template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
                                              PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
                                  storeOp.indices(), sourceIndices)))
    return failure();

  replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
  return success();
}
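
/// Populates `patterns` with the subview folding patterns for memref.load,
/// memref.store, vector.transfer_read, and vector.transfer_write.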
void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
               LoadOpOfSubViewFolder<vector::TransferReadOp>,
               StoreOpOfSubViewFolder<memref::StoreOp>,
               StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"

struct FoldSubViewOpsPass final
    : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldSubViewOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldSubViewOpPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
                                     std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
  return std::make_unique<FoldSubViewOpsPass>();
}