//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loads from and stores to the results of
// memref.subview ops into loads from and stores to the original source
// memref, rewriting the access indices accordingly.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the `indices` of a load/store operation where the memref is a result
/// of a subview op, returns the indices with respect to the source memref of
/// the subview op. For example:
///
/// %0 = ... : memref<12x42xf32>
/// %1 = subview %0[%arg0, %arg1][4, 4][%stride1, %stride2] :
///          memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]>
/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
///
/// could be folded into:
///
/// %2 = load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] :
///          memref<12x42xf32>
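///
/// If the subview is rank-reducing, a zero index is first materialized for
/// every dropped source dimension before the stride/offset arithmetic is
/// applied.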
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
                     memref::SubViewOp subViewOp, ValueRange indices,
                     SmallVectorImpl<Value> &sourceIndices) {
  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();

  SmallVector<Value> useIndices;
  // Check for the rank-reducing case: materialize a zero index for every
  // dropped (unit) source dimension.
  unsigned resultDim = 0;
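  // `getDroppedDims` has one bit per source dimension, set when the
  // rank-reducing subview drops that dimension.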
  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
  for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
    if (unusedDims.test(dim))
      useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
    else
      useIndices.push_back(indices[resultDim++]);
  }
  if (useIndices.size() != mixedOffsets.size())
    return failure();
  sourceIndices.resize(useIndices.size());
  for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
    SmallVector<Value> dynamicOperands;
    AffineExpr expr = rewriter.getAffineDimExpr(0);
    unsigned numSymbols = 0;
    dynamicOperands.push_back(useIndices[index]);

    // Multiply by the stride.
    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
      expr = expr * attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedStrides[index].get<Value>());
      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
    }

    // Add the offset.
    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
      expr = expr + attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
    }
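    // `expr` now computes index * stride + offset; dynamic strides/offsets
    // are bound to the symbols collected in `dynamicOperands`, while constant
    // ones are folded directly into the expression.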
    sourceIndices[index] = rewriter.create<AffineApplyOp>(
        loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
  }
  return success();
}

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.memref();
}

static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.source();
}

/// Given the permutation map of the original
/// `vector.transfer_read`/`vector.transfer_write` operation, compute the
/// permutation map to use after the subview is folded into it.
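/// For example, for a rank-reducing subview of a rank-3 source that drops
/// dim 1, the result-to-source dim map is (d0, d1, d2) -> (d0, d2); composing
/// an identity transfer map (d0, d1) -> (d0, d1) with it yields
/// (d0, d1, d2) -> (d0, d2) over the source dims.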
static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
                                           memref::SubViewOp subViewOp,
                                           AffineMap currPermutationMap) {
  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
  SmallVector<AffineExpr> exprs;
  int64_t sourceRank = subViewOp.getSourceType().getRank();
  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
    if (unusedDims.test(dim))
      continue;
    exprs.push_back(getAffineDimExpr(dim, context));
  }
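  // The composed map first projects the source dims onto the subview's result
  // dims and then applies the original permutation map, so the resulting map
  // can be used directly with source indices.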
  auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
  return AffineMapAttr::get(
      currPermutationMap.compose(resultDimToSourceDimMap));
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges a subview operation with a load/transfer_read operation.
template <typename OpTy>
class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Rewrites `loadOp` to load directly from the source of `subViewOp`;
  /// fails if the rewrite is not supported (e.g. 0-d transfers).
  LogicalResult replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
                          ArrayRef<Value> sourceIndices,
                          PatternRewriter &rewriter) const;
};

/// Merges a subview operation with a store/transfer_write operation.
template <typename OpTy>
class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Rewrites `storeOp` to store directly to the source of `subViewOp`;
  /// fails if the rewrite is not supported (e.g. 0-d transfers).
  LogicalResult replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
                          ArrayRef<Value> sourceIndices,
                          PatternRewriter &rewriter) const;
};

template <typename LoadOpTy>
LogicalResult LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
    LoadOpTy loadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.source(),
                                        sourceIndices);
  return success();
}

template <>
LogicalResult LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
    vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  // TODO: support 0-d corner case. Fail here so that the pattern does not
  // report success without actually rewriting the op.
  if (transferReadOp.getTransferRank() == 0)
    return failure();
  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
      transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
      sourceIndices,
      getPermutationMapAttr(rewriter.getContext(), subViewOp,
                            transferReadOp.permutation_map()),
      transferReadOp.padding(),
      /*mask=*/Value(), transferReadOp.in_boundsAttr());
  return success();
}

template <typename StoreOpTy>
LogicalResult StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
    StoreOpTy storeOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.value(),
                                         subViewOp.source(), sourceIndices);
  return success();
}

template <>
LogicalResult StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
    vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  // TODO: support 0-d corner case. Fail here so that the pattern does not
  // report success without actually rewriting the op.
  if (transferWriteOp.getTransferRank() == 0)
    return failure();
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
      sourceIndices,
      getPermutationMapAttr(rewriter.getContext(), subViewOp,
                            transferWriteOp.permutation_map()),
      transferWriteOp.in_boundsAttr());
  return success();
}
} // namespace

template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
                                             PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
                                  loadOp.indices(), sourceIndices)))
    return failure();

  return replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
}

template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
                                              PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
                                  storeOp.indices(), sourceIndices)))
    return failure();

  return replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
}

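/// Collects patterns that fold memref.subview ops into their load/store and
/// vector transfer users.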
void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
               LoadOpOfSubViewFolder<memref::LoadOp>,
               LoadOpOfSubViewFolder<vector::TransferReadOp>,
               StoreOpOfSubViewFolder<AffineStoreOp>,
               StoreOpOfSubViewFolder<memref::StoreOp>,
               StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"

struct FoldSubViewOpsPass final
    : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldSubViewOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldSubViewOpPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
                                     std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
  return std::make_unique<FoldSubViewOpsPass>();
}