1 //===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass folds loading/storing from/to subview ops into
10 // loading/storing from/to the original memref.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
14 
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Affine/Utils.h"
17 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
18 #include "mlir/Dialect/MemRef/IR/MemRef.h"
19 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
20 #include "mlir/Dialect/Vector/IR/VectorOps.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "llvm/ADT/SmallBitVector.h"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Utility functions
28 //===----------------------------------------------------------------------===//
29 
30 /// Given the 'indices' of an load/store operation where the memref is a result
31 /// of a subview op, returns the indices w.r.t to the source memref of the
32 /// subview op. For example
33 ///
34 /// %0 = ... : memref<12x42xf32>
35 /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
36 ///          memref<4x4xf32, offset=?, strides=[?, ?]>
37 /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
38 ///
39 /// could be folded into
40 ///
41 /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
42 ///          memref<12x42xf32>
43 static LogicalResult
44 resolveSourceIndices(Location loc, PatternRewriter &rewriter,
45                      memref::SubViewOp subViewOp, ValueRange indices,
46                      SmallVectorImpl<Value> &sourceIndices) {
47   SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
48   SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
49   SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
50 
51   SmallVector<Value> useIndices;
52   // Check if this is rank-reducing case. Then for every unit-dim size add a
53   // zero to the indices.
54   unsigned resultDim = 0;
55   llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
56   for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
57     if (unusedDims.test(dim))
58       useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
59     else
60       useIndices.push_back(indices[resultDim++]);
61   }
62   if (useIndices.size() != mixedOffsets.size())
63     return failure();
64   sourceIndices.resize(useIndices.size());
65   for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
66     SmallVector<Value> dynamicOperands;
67     AffineExpr expr = rewriter.getAffineDimExpr(0);
68     unsigned numSymbols = 0;
69     dynamicOperands.push_back(useIndices[index]);
70 
71     // Multiply the stride;
72     if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
73       expr = expr * attr.cast<IntegerAttr>().getInt();
74     } else {
75       dynamicOperands.push_back(mixedStrides[index].get<Value>());
76       expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
77     }
78 
79     // Add the offset.
80     if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
81       expr = expr + attr.cast<IntegerAttr>().getInt();
82     } else {
83       dynamicOperands.push_back(mixedOffsets[index].get<Value>());
84       expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
85     }
86     Location loc = subViewOp.getLoc();
87     sourceIndices[index] = rewriter.create<AffineApplyOp>(
88         loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
89   }
90   return success();
91 }
92 
93 /// Helpers to access the memref operand for each op.
94 template <typename LoadOrStoreOpTy>
95 static Value getMemRefOperand(LoadOrStoreOpTy op) {
96   return op.memref();
97 }
98 
99 static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
100 
101 static Value getMemRefOperand(vector::TransferWriteOp op) {
102   return op.source();
103 }
104 
105 /// Given the permutation map of the original
106 /// `vector.transfer_read`/`vector.transfer_write` operations compute the
107 /// permutation map to use after the subview is folded with it.
108 static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
109                                            memref::SubViewOp subViewOp,
110                                            AffineMap currPermutationMap) {
111   llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
112   SmallVector<AffineExpr> exprs;
113   int64_t sourceRank = subViewOp.getSourceType().getRank();
114   for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
115     if (unusedDims.test(dim))
116       continue;
117     exprs.push_back(getAffineDimExpr(dim, context));
118   }
119   auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
120   return AffineMapAttr::get(
121       currPermutationMap.compose(resultDimToSourceDimMap));
122 }
123 
124 //===----------------------------------------------------------------------===//
125 // Patterns
126 //===----------------------------------------------------------------------===//
127 
128 namespace {
129 /// Merges subview operation with load/transferRead operation.
130 template <typename OpTy>
131 class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
132 public:
133   using OpRewritePattern<OpTy>::OpRewritePattern;
134 
135   LogicalResult matchAndRewrite(OpTy loadOp,
136                                 PatternRewriter &rewriter) const override;
137 
138 private:
139   void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
140                  ArrayRef<Value> sourceIndices,
141                  PatternRewriter &rewriter) const;
142 };
143 
144 /// Merges subview operation with store/transferWriteOp operation.
145 template <typename OpTy>
146 class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
147 public:
148   using OpRewritePattern<OpTy>::OpRewritePattern;
149 
150   LogicalResult matchAndRewrite(OpTy storeOp,
151                                 PatternRewriter &rewriter) const override;
152 
153 private:
154   void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
155                  ArrayRef<Value> sourceIndices,
156                  PatternRewriter &rewriter) const;
157 };
158 
159 template <typename LoadOpTy>
160 void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
161     LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
162     PatternRewriter &rewriter) const {
163   rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.source(),
164                                         sourceIndices);
165 }
166 
167 template <>
168 void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
169     vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
170     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
171   // TODO: support 0-d corner case.
172   if (transferReadOp.getTransferRank() == 0)
173     return;
174   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
175       transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
176       sourceIndices,
177       getPermutationMapAttr(rewriter.getContext(), subViewOp,
178                             transferReadOp.permutation_map()),
179       transferReadOp.padding(),
180       /*mask=*/Value(), transferReadOp.in_boundsAttr());
181 }
182 
183 template <typename StoreOpTy>
184 void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
185     StoreOpTy storeOp, memref::SubViewOp subViewOp,
186     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
187   rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.value(),
188                                          subViewOp.source(), sourceIndices);
189 }
190 
191 template <>
192 void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
193     vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
194     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
195   // TODO: support 0-d corner case.
196   if (transferWriteOp.getTransferRank() == 0)
197     return;
198   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
199       transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
200       sourceIndices,
201       getPermutationMapAttr(rewriter.getContext(), subViewOp,
202                             transferWriteOp.permutation_map()),
203       transferWriteOp.in_boundsAttr());
204 }
205 } // namespace
206 
207 template <typename OpTy>
208 LogicalResult
209 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
210                                              PatternRewriter &rewriter) const {
211   auto subViewOp =
212       getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
213   if (!subViewOp)
214     return failure();
215 
216   ValueRange indices = loadOp.indices();
217   // For affine ops, we need to apply the map to get the operands to get the
218   // "actual" indices.
219   if (auto affineLoadOp = dyn_cast<AffineLoadOp>(loadOp.getOperation())) {
220     auto expandedIndices =
221         expandAffineMap(rewriter, loadOp.getLoc(), affineLoadOp.getAffineMap(),
222                         affineLoadOp.indices());
223     if (!expandedIndices)
224       return failure();
225     indices = expandedIndices.getValue();
226   }
227   SmallVector<Value, 4> sourceIndices;
228   if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, indices,
229                                   sourceIndices)))
230     return failure();
231 
232   replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
233   return success();
234 }
235 
236 template <typename OpTy>
237 LogicalResult
238 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
239                                               PatternRewriter &rewriter) const {
240   auto subViewOp =
241       getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
242   if (!subViewOp)
243     return failure();
244 
245   ValueRange indices = storeOp.indices();
246   // For affine ops, we need to apply the map to get the operands to get the
247   // "actual" indices.
248   if (auto affineStoreOp = dyn_cast<AffineStoreOp>(storeOp.getOperation())) {
249     auto expandedIndices =
250         expandAffineMap(rewriter, storeOp.getLoc(),
251                         affineStoreOp.getAffineMap(), affineStoreOp.indices());
252     if (!expandedIndices)
253       return failure();
254     indices = expandedIndices.getValue();
255   }
256   SmallVector<Value, 4> sourceIndices;
257   if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
258                                   indices, sourceIndices)))
259     return failure();
260 
261   replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
262   return success();
263 }
264 
265 void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
266   patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
267                LoadOpOfSubViewFolder<memref::LoadOp>,
268                LoadOpOfSubViewFolder<vector::TransferReadOp>,
269                StoreOpOfSubViewFolder<AffineStoreOp>,
270                StoreOpOfSubViewFolder<memref::StoreOp>,
271                StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
272       patterns.getContext());
273 }
274 
275 //===----------------------------------------------------------------------===//
276 // Pass registration
277 //===----------------------------------------------------------------------===//
278 
279 namespace {
280 
281 #define GEN_PASS_CLASSES
282 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
283 
284 struct FoldSubViewOpsPass final
285     : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
286   void runOnOperation() override;
287 };
288 
289 } // namespace
290 
291 void FoldSubViewOpsPass::runOnOperation() {
292   RewritePatternSet patterns(&getContext());
293   memref::populateFoldSubViewOpPatterns(patterns);
294   (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
295                                      std::move(patterns));
296 }
297 
298 std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
299   return std::make_unique<FoldSubViewOpsPass>();
300 }
301