1 //===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass folds loading/storing from/to subview ops into
10 // loading/storing from/to the original memref.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/MemRef/IR/MemRef.h"
16 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/Dialect/Vector/VectorOps.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Utility functions
26 //===----------------------------------------------------------------------===//
27 
28 /// Given the 'indices' of an load/store operation where the memref is a result
29 /// of a subview op, returns the indices w.r.t to the source memref of the
30 /// subview op. For example
31 ///
32 /// %0 = ... : memref<12x42xf32>
33 /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
34 ///          memref<4x4xf32, offset=?, strides=[?, ?]>
35 /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
36 ///
37 /// could be folded into
38 ///
39 /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
40 ///          memref<12x42xf32>
41 static LogicalResult
42 resolveSourceIndices(Location loc, PatternRewriter &rewriter,
43                      memref::SubViewOp subViewOp, ValueRange indices,
44                      SmallVectorImpl<Value> &sourceIndices) {
45   SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
46   SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
47   SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();
48 
49   SmallVector<Value> useIndices;
50   // Check if this is rank-reducing case. Then for every unit-dim size add a
51   // zero to the indices.
52   unsigned resultDim = 0;
53   llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
54   for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
55     if (unusedDims.count(dim))
56       useIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
57     else
58       useIndices.push_back(indices[resultDim++]);
59   }
60   if (useIndices.size() != mixedOffsets.size())
61     return failure();
62   sourceIndices.resize(useIndices.size());
63   for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
64     SmallVector<Value> dynamicOperands;
65     AffineExpr expr = rewriter.getAffineDimExpr(0);
66     unsigned numSymbols = 0;
67     dynamicOperands.push_back(useIndices[index]);
68 
69     // Multiply the stride;
70     if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
71       expr = expr * attr.cast<IntegerAttr>().getInt();
72     } else {
73       dynamicOperands.push_back(mixedStrides[index].get<Value>());
74       expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
75     }
76 
77     // Add the offset.
78     if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
79       expr = expr + attr.cast<IntegerAttr>().getInt();
80     } else {
81       dynamicOperands.push_back(mixedOffsets[index].get<Value>());
82       expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
83     }
84     Location loc = subViewOp.getLoc();
85     sourceIndices[index] = rewriter.create<AffineApplyOp>(
86         loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
87   }
88   return success();
89 }
90 
/// Helpers to access the memref operand for each op. memref.load/store name
/// the operand `memref`; vector.transfer_read/write name it `source`.
static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }

static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }

static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.source();
}
101 
102 /// Given the permutation map of the original
103 /// `vector.transfer_read`/`vector.transfer_write` operations compute the
104 /// permutation map to use after the subview is folded with it.
105 static AffineMap getPermutationMap(MLIRContext *context,
106                                    memref::SubViewOp subViewOp,
107                                    AffineMap currPermutationMap) {
108   llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
109   SmallVector<AffineExpr> exprs;
110   int64_t sourceRank = subViewOp.getSourceType().getRank();
111   for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
112     if (unusedDims.count(dim))
113       continue;
114     exprs.push_back(getAffineDimExpr(dim, context));
115   }
116   auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
117   return currPermutationMap.compose(resultDimToSourceDimMap);
118 }
119 
120 //===----------------------------------------------------------------------===//
121 // Patterns
122 //===----------------------------------------------------------------------===//
123 
124 namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Replaces `loadOp` with an equivalent op that reads directly from the
  /// subview's source memref using the already-resolved `sourceIndices`.
  /// Specialized below for each supported op type.
  void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};
139 
/// Merges subview operation with store/transferWriteOp operation.
template <typename OpTy>
class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;

private:
  /// Replaces `storeOp` with an equivalent op that writes directly to the
  /// subview's source memref using the already-resolved `sourceIndices`.
  /// Specialized below for each supported op type.
  void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};
154 
/// memref.load of a subview folds to a memref.load of the subview's source
/// memref at the resolved source indices.
template <>
void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
    memref::LoadOp loadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(),
                                              sourceIndices);
}
162 
/// vector.transfer_read of a subview folds to a transfer_read of the source
/// memref; the permutation map is recomputed so it refers to the source's
/// dimensions when the subview is rank-reducing.
template <>
void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
    vector::TransferReadOp loadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
      loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
      getPermutationMap(rewriter.getContext(), subViewOp,
                        loadOp.permutation_map()),
      loadOp.padding(), loadOp.in_boundsAttr());
}
173 
/// memref.store into a subview folds to a memref.store into the subview's
/// source memref at the resolved source indices.
template <>
void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp(
    memref::StoreOp storeOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<memref::StoreOp>(
      storeOp, storeOp.value(), subViewOp.source(), sourceIndices);
}
181 
/// vector.transfer_write into a subview folds to a transfer_write into the
/// source memref; the permutation map is recomputed so it refers to the
/// source's dimensions when the subview is rank-reducing.
template <>
void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
    vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
      sourceIndices,
      getPermutationMap(rewriter.getContext(), subViewOp,
                        transferWriteOp.permutation_map()),
      transferWriteOp.in_boundsAttr());
}
193 } // namespace
194 
195 template <typename OpTy>
196 LogicalResult
197 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
198                                              PatternRewriter &rewriter) const {
199   auto subViewOp =
200       getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
201   if (!subViewOp)
202     return failure();
203 
204   SmallVector<Value, 4> sourceIndices;
205   if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
206                                   loadOp.indices(), sourceIndices)))
207     return failure();
208 
209   replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
210   return success();
211 }
212 
213 template <typename OpTy>
214 LogicalResult
215 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
216                                               PatternRewriter &rewriter) const {
217   auto subViewOp =
218       getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
219   if (!subViewOp)
220     return failure();
221 
222   SmallVector<Value, 4> sourceIndices;
223   if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
224                                   storeOp.indices(), sourceIndices)))
225     return failure();
226 
227   replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
228   return success();
229 }
230 
231 void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
232   patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
233                LoadOpOfSubViewFolder<vector::TransferReadOp>,
234                StoreOpOfSubViewFolder<memref::StoreOp>,
235                StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
236       patterns.getContext());
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // Pass registration
241 //===----------------------------------------------------------------------===//
242 
namespace {

// Pulls in the tablegen-generated pass base class (FoldSubViewOpsBase).
#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"

/// Pass that greedily applies the subview-folding patterns to the operation
/// it is run on.
struct FoldSubViewOpsPass final
    : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
  void runOnOperation() override;
};

} // namespace
254 
void FoldSubViewOpsPass::runOnOperation() {
  // Collect the folding patterns and apply them to a fixed point over all
  // regions of the root operation. The result of the greedy driver is
  // intentionally ignored: not converging is not an error for this pass.
  RewritePatternSet patterns(&getContext());
  memref::populateFoldSubViewOpPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
                                     std::move(patterns));
}
261 
/// Public factory for the pass (declared in MemRef/Transforms/Passes.h).
std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
  return std::make_unique<FoldSubViewOpsPass>();
}
265