1 //===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass folds loading/storing from/to subview ops into
10 // loading/storing from/to the original memref.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
18 #include "mlir/Dialect/Vector/IR/VectorOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Utility functions
26 //===----------------------------------------------------------------------===//
27 
/// Given the `indices` of a load/store operation where the memref is a result
/// of a subview op, returns the indices w.r.t. the source memref of the
/// subview op. For example:
31 ///
32 /// %0 = ... : memref<12x42xf32>
/// %1 = subview %0[%arg0, %arg1][4, 4][%stride1, %stride2] :
///          memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]>
35 /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
36 ///
37 /// could be folded into
38 ///
/// %2 = load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] :
///          memref<12x42xf32>
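///
/// In the rank-reducing case, every dropped unit dimension gets a constant
/// zero index. An illustrative sketch (IR shown schematically):
///
/// %1 = subview %0[%arg0, %arg1][4, 1][%stride1, %stride2] :
///          memref<12x42xf32> to memref<4xf32, offset=?, strides=[?]>
/// %2 = load %1[%i1] : memref<4xf32, offset=?, strides=[?]>
///
/// resolves to
///
/// %2 = load %0[%arg0 + %i1 * %stride1, %arg1] : memref<12x42xf32>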
41 static LogicalResult
42 resolveSourceIndices(Location loc, PatternRewriter &rewriter,
43                      memref::SubViewOp subViewOp, ValueRange indices,
44                      SmallVectorImpl<Value> &sourceIndices) {
45   SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
46   SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
47   SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
48 
49   SmallVector<Value> useIndices;
  // Check if this is a rank-reducing case; if so, add a constant zero index
  // for every dropped unit dimension.
52   unsigned resultDim = 0;
53   llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
54   for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
55     if (unusedDims.count(dim))
56       useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
57     else
58       useIndices.push_back(indices[resultDim++]);
59   }
60   if (useIndices.size() != mixedOffsets.size())
61     return failure();
62   sourceIndices.resize(useIndices.size());
63   for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
64     SmallVector<Value> dynamicOperands;
65     AffineExpr expr = rewriter.getAffineDimExpr(0);
66     unsigned numSymbols = 0;
67     dynamicOperands.push_back(useIndices[index]);
68 
    // Multiply the index by the stride.
70     if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
71       expr = expr * attr.cast<IntegerAttr>().getInt();
72     } else {
73       dynamicOperands.push_back(mixedStrides[index].get<Value>());
74       expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
75     }
76 
77     // Add the offset.
78     if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
79       expr = expr + attr.cast<IntegerAttr>().getInt();
80     } else {
81       dynamicOperands.push_back(mixedOffsets[index].get<Value>());
82       expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
83     }
    sourceIndices[index] = rewriter.create<AffineApplyOp>(
        loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
87   }
88   return success();
89 }
90 
91 /// Helpers to access the memref operand for each op.
92 template <typename LoadOrStoreOpTy>
93 static Value getMemRefOperand(LoadOrStoreOpTy op) {
94   return op.memref();
95 }
96 
97 static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }
98 
99 static Value getMemRefOperand(vector::TransferWriteOp op) {
100   return op.source();
101 }
102 
/// Given the permutation map of the original
/// `vector.transfer_read`/`vector.transfer_write` operation, compute the
/// permutation map to use after the subview is folded into it.
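///
/// For example (illustrative): for a rank-3 source with dimension 1 dropped,
/// `resultDimToSourceDimMap` is (d0, d1, d2) -> (d0, d2), so composing a
/// transposed permutation map (d0, d1) -> (d1, d0) with it yields
/// (d0, d1, d2) -> (d2, d0).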
106 static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
107                                            memref::SubViewOp subViewOp,
108                                            AffineMap currPermutationMap) {
109   llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
110   SmallVector<AffineExpr> exprs;
111   int64_t sourceRank = subViewOp.getSourceType().getRank();
112   for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
113     if (unusedDims.count(dim))
114       continue;
115     exprs.push_back(getAffineDimExpr(dim, context));
116   }
117   auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
118   return AffineMapAttr::get(
119       currPermutationMap.compose(resultDimToSourceDimMap));
120 }
121 
122 //===----------------------------------------------------------------------===//
123 // Patterns
124 //===----------------------------------------------------------------------===//
125 
126 namespace {
/// Merges a subview operation with a load/transfer_read operation.
128 template <typename OpTy>
129 class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
130 public:
131   using OpRewritePattern<OpTy>::OpRewritePattern;
132 
133   LogicalResult matchAndRewrite(OpTy loadOp,
134                                 PatternRewriter &rewriter) const override;
135 
136 private:
  /// Performs the rewrite; returns failure if the rewrite is unsupported
  /// (e.g. 0-d vector transfers), leaving the IR unchanged.
  LogicalResult replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
                          ArrayRef<Value> sourceIndices,
                          PatternRewriter &rewriter) const;
140 };
141 
/// Merges a subview operation with a store/transfer_write operation.
143 template <typename OpTy>
144 class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
145 public:
146   using OpRewritePattern<OpTy>::OpRewritePattern;
147 
148   LogicalResult matchAndRewrite(OpTy storeOp,
149                                 PatternRewriter &rewriter) const override;
150 
151 private:
  /// Performs the rewrite; returns failure if the rewrite is unsupported
  /// (e.g. 0-d vector transfers), leaving the IR unchanged.
  LogicalResult replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
                          ArrayRef<Value> sourceIndices,
                          PatternRewriter &rewriter) const;
155 };
156 
template <typename LoadOpTy>
LogicalResult LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
    LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
    PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.source(),
                                        sourceIndices);
  return success();
}
164 
template <>
LogicalResult LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
    vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  // TODO: support 0-d corner case.
  if (transferReadOp.getTransferRank() == 0)
    return failure();
  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
      transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
      sourceIndices,
      getPermutationMapAttr(rewriter.getContext(), subViewOp,
                            transferReadOp.permutation_map()),
      transferReadOp.padding(),
      /*mask=*/Value(), transferReadOp.in_boundsAttr());
  return success();
}
180 
template <typename StoreOpTy>
LogicalResult StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
    StoreOpTy storeOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.value(),
                                         subViewOp.source(), sourceIndices);
  return success();
}
188 
template <>
LogicalResult StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
    vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  // TODO: support 0-d corner case.
  if (transferWriteOp.getTransferRank() == 0)
    return failure();
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
      sourceIndices,
      getPermutationMapAttr(rewriter.getContext(), subViewOp,
                            transferWriteOp.permutation_map()),
      transferWriteOp.in_boundsAttr());
  return success();
}
203 } // namespace
204 
205 template <typename OpTy>
206 LogicalResult
207 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
208                                              PatternRewriter &rewriter) const {
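  // Only fire on loads whose memref operand is produced by a subview op.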
209   auto subViewOp =
210       getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
211   if (!subViewOp)
212     return failure();
213 
214   SmallVector<Value, 4> sourceIndices;
215   if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
216                                   loadOp.indices(), sourceIndices)))
217     return failure();
218 
  return replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
221 }
222 
223 template <typename OpTy>
224 LogicalResult
225 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
226                                               PatternRewriter &rewriter) const {
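  // Only fire on stores whose memref operand is produced by a subview op.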
227   auto subViewOp =
228       getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
229   if (!subViewOp)
230     return failure();
231 
232   SmallVector<Value, 4> sourceIndices;
233   if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
234                                   storeOp.indices(), sourceIndices)))
235     return failure();
236 
  return replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
239 }
240 
241 void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
242   patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
243                LoadOpOfSubViewFolder<memref::LoadOp>,
244                LoadOpOfSubViewFolder<vector::TransferReadOp>,
245                StoreOpOfSubViewFolder<AffineStoreOp>,
246                StoreOpOfSubViewFolder<memref::StoreOp>,
247                StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
248       patterns.getContext());
249 }
250 
251 //===----------------------------------------------------------------------===//
252 // Pass registration
253 //===----------------------------------------------------------------------===//
254 
255 namespace {
256 
257 #define GEN_PASS_CLASSES
258 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
259 
260 struct FoldSubViewOpsPass final
261     : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
262   void runOnOperation() override;
263 };
264 
265 } // namespace
266 
267 void FoldSubViewOpsPass::runOnOperation() {
268   RewritePatternSet patterns(&getContext());
269   memref::populateFoldSubViewOpPatterns(patterns);
270   (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
271                                      std::move(patterns));
272 }
273 
274 std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
275   return std::make_unique<FoldSubViewOpsPass>();
276 }
277