1 //===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass folds loading/storing from/to subview ops into
10 // loading/storing from/to the original memref.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
17 #include "mlir/Dialect/MemRef/IR/MemRef.h"
18 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
19 #include "mlir/Dialect/Vector/IR/VectorOps.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22 #include "llvm/ADT/SmallBitVector.h"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Utility functions
28 //===----------------------------------------------------------------------===//
29 
30 /// Given the 'indices' of an load/store operation where the memref is a result
31 /// of a subview op, returns the indices w.r.t to the source memref of the
32 /// subview op. For example
33 ///
34 /// %0 = ... : memref<12x42xf32>
35 /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
36 ///          memref<4x4xf32, offset=?, strides=[?, ?]>
37 /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
38 ///
39 /// could be folded into
40 ///
41 /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
42 ///          memref<12x42xf32>
43 static LogicalResult
resolveSourceIndices(Location loc,PatternRewriter & rewriter,memref::SubViewOp subViewOp,ValueRange indices,SmallVectorImpl<Value> & sourceIndices)44 resolveSourceIndices(Location loc, PatternRewriter &rewriter,
45                      memref::SubViewOp subViewOp, ValueRange indices,
46                      SmallVectorImpl<Value> &sourceIndices) {
47   SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
48   SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
49   SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
50 
51   SmallVector<Value> useIndices;
52   // Check if this is rank-reducing case. Then for every unit-dim size add a
53   // zero to the indices.
54   unsigned resultDim = 0;
55   llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
56   for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
57     if (unusedDims.test(dim))
58       useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
59     else
60       useIndices.push_back(indices[resultDim++]);
61   }
62   if (useIndices.size() != mixedOffsets.size())
63     return failure();
64   sourceIndices.resize(useIndices.size());
65   for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
66     SmallVector<Value> dynamicOperands;
67     AffineExpr expr = rewriter.getAffineDimExpr(0);
68     unsigned numSymbols = 0;
69     dynamicOperands.push_back(useIndices[index]);
70 
71     // Multiply the stride;
72     if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
73       expr = expr * attr.cast<IntegerAttr>().getInt();
74     } else {
75       dynamicOperands.push_back(mixedStrides[index].get<Value>());
76       expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
77     }
78 
79     // Add the offset.
80     if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
81       expr = expr + attr.cast<IntegerAttr>().getInt();
82     } else {
83       dynamicOperands.push_back(mixedOffsets[index].get<Value>());
84       expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
85     }
86     Location loc = subViewOp.getLoc();
87     sourceIndices[index] = rewriter.create<AffineApplyOp>(
88         loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
89   }
90   return success();
91 }
92 
93 /// Helpers to access the memref operand for each op.
94 template <typename LoadOrStoreOpTy>
getMemRefOperand(LoadOrStoreOpTy op)95 static Value getMemRefOperand(LoadOrStoreOpTy op) {
96   return op.getMemref();
97 }
98 
getMemRefOperand(vector::TransferReadOp op)99 static Value getMemRefOperand(vector::TransferReadOp op) {
100   return op.getSource();
101 }
102 
getMemRefOperand(vector::TransferWriteOp op)103 static Value getMemRefOperand(vector::TransferWriteOp op) {
104   return op.getSource();
105 }
106 
107 /// Given the permutation map of the original
108 /// `vector.transfer_read`/`vector.transfer_write` operations compute the
109 /// permutation map to use after the subview is folded with it.
getPermutationMapAttr(MLIRContext * context,memref::SubViewOp subViewOp,AffineMap currPermutationMap)110 static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
111                                            memref::SubViewOp subViewOp,
112                                            AffineMap currPermutationMap) {
113   llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
114   SmallVector<AffineExpr> exprs;
115   int64_t sourceRank = subViewOp.getSourceType().getRank();
116   for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
117     if (unusedDims.test(dim))
118       continue;
119     exprs.push_back(getAffineDimExpr(dim, context));
120   }
121   auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
122   return AffineMapAttr::get(
123       currPermutationMap.compose(resultDimToSourceDimMap));
124 }
125 
126 //===----------------------------------------------------------------------===//
127 // Patterns
128 //===----------------------------------------------------------------------===//
129 
130 namespace {
131 /// Merges subview operation with load/transferRead operation.
132 template <typename OpTy>
133 class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
134 public:
135   using OpRewritePattern<OpTy>::OpRewritePattern;
136 
137   LogicalResult matchAndRewrite(OpTy loadOp,
138                                 PatternRewriter &rewriter) const override;
139 
140 private:
141   void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
142                  ArrayRef<Value> sourceIndices,
143                  PatternRewriter &rewriter) const;
144 };
145 
146 /// Merges subview operation with store/transferWriteOp operation.
147 template <typename OpTy>
148 class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
149 public:
150   using OpRewritePattern<OpTy>::OpRewritePattern;
151 
152   LogicalResult matchAndRewrite(OpTy storeOp,
153                                 PatternRewriter &rewriter) const override;
154 
155 private:
156   void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
157                  ArrayRef<Value> sourceIndices,
158                  PatternRewriter &rewriter) const;
159 };
160 
161 template <typename LoadOpTy>
replaceOp(LoadOpTy loadOp,memref::SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const162 void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
163     LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
164     PatternRewriter &rewriter) const {
165   rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.getSource(),
166                                         sourceIndices);
167 }
168 
169 template <>
replaceOp(vector::TransferReadOp transferReadOp,memref::SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const170 void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
171     vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
172     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
173   // TODO: support 0-d corner case.
174   if (transferReadOp.getTransferRank() == 0)
175     return;
176   rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
177       transferReadOp, transferReadOp.getVectorType(), subViewOp.getSource(),
178       sourceIndices,
179       getPermutationMapAttr(rewriter.getContext(), subViewOp,
180                             transferReadOp.getPermutationMap()),
181       transferReadOp.getPadding(),
182       /*mask=*/Value(), transferReadOp.getInBoundsAttr());
183 }
184 
185 template <typename StoreOpTy>
replaceOp(StoreOpTy storeOp,memref::SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const186 void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
187     StoreOpTy storeOp, memref::SubViewOp subViewOp,
188     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
189   rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.getValue(),
190                                          subViewOp.getSource(), sourceIndices);
191 }
192 
193 template <>
replaceOp(vector::TransferWriteOp transferWriteOp,memref::SubViewOp subViewOp,ArrayRef<Value> sourceIndices,PatternRewriter & rewriter) const194 void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
195     vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
196     ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
197   // TODO: support 0-d corner case.
198   if (transferWriteOp.getTransferRank() == 0)
199     return;
200   rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
201       transferWriteOp, transferWriteOp.getVector(), subViewOp.getSource(),
202       sourceIndices,
203       getPermutationMapAttr(rewriter.getContext(), subViewOp,
204                             transferWriteOp.getPermutationMap()),
205       transferWriteOp.getInBoundsAttr());
206 }
207 } // namespace
208 
209 template <typename OpTy>
210 LogicalResult
matchAndRewrite(OpTy loadOp,PatternRewriter & rewriter) const211 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
212                                              PatternRewriter &rewriter) const {
213   auto subViewOp =
214       getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
215   if (!subViewOp)
216     return failure();
217 
218   SmallVector<Value, 4> sourceIndices;
219   if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
220                                   loadOp.getIndices(), sourceIndices)))
221     return failure();
222 
223   replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
224   return success();
225 }
226 
227 template <typename OpTy>
228 LogicalResult
matchAndRewrite(OpTy storeOp,PatternRewriter & rewriter) const229 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
230                                               PatternRewriter &rewriter) const {
231   auto subViewOp =
232       getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
233   if (!subViewOp)
234     return failure();
235 
236   SmallVector<Value, 4> sourceIndices;
237   if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
238                                   storeOp.getIndices(), sourceIndices)))
239     return failure();
240 
241   replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
242   return success();
243 }
244 
populateFoldSubViewOpPatterns(RewritePatternSet & patterns)245 void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
246   patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
247                LoadOpOfSubViewFolder<memref::LoadOp>,
248                LoadOpOfSubViewFolder<vector::TransferReadOp>,
249                StoreOpOfSubViewFolder<AffineStoreOp>,
250                StoreOpOfSubViewFolder<memref::StoreOp>,
251                StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
252       patterns.getContext());
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // Pass registration
257 //===----------------------------------------------------------------------===//
258 
259 namespace {
260 
261 struct FoldSubViewOpsPass final
262     : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
263   void runOnOperation() override;
264 };
265 
266 } // namespace
267 
runOnOperation()268 void FoldSubViewOpsPass::runOnOperation() {
269   RewritePatternSet patterns(&getContext());
270   memref::populateFoldSubViewOpPatterns(patterns);
271   (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
272 }
273 
createFoldSubViewOpsPass()274 std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
275   return std::make_unique<FoldSubViewOpsPass>();
276 }
277