//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the 'indices' of a load/store operation where the memref is a result
/// of a subview op, returns the indices w.r.t. the source memref of the
/// subview op. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = subview %0[%arg0, %arg1][4, 4][%stride1, %stride2] :
///          memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]>
/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
///
/// could be folded into
///
/// %2 = load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] :
///          memref<12x42xf32>
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
                     memref::SubViewOp subViewOp, ValueRange indices,
                     SmallVectorImpl<Value> &sourceIndices) {
  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();

  SmallVector<Value> useIndices;
  // Check if this is a rank-reducing case: for every dropped unit dimension of
  // the source, insert a zero index.
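  // For example (illustrative): if %sv = memref.subview %src[...][4, 1, 4][...]
  // rank-reduces a 3-D source by dropping its unit middle dim, a 2-D access
  // %sv[%i, %j] expands to the 3-D indices (%i, 0, %j) before the per-dim
  // stride and offset are applied.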
  unsigned resultDim = 0;
  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
  for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
    if (unusedDims.test(dim))
      useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
    else
      useIndices.push_back(indices[resultDim++]);
  }
  if (useIndices.size() != mixedOffsets.size())
    return failure();
  sourceIndices.resize(useIndices.size());
  for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
    SmallVector<Value> dynamicOperands;
    AffineExpr expr = rewriter.getAffineDimExpr(0);
    unsigned numSymbols = 0;
    dynamicOperands.push_back(useIndices[index]);

    // Multiply by the stride.
    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
      expr = expr * attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedStrides[index].get<Value>());
      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
    }

    // Add the offset.
    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
      expr = expr + attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
    }
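
    // `expr` now has the form index * stride + offset, with any dynamic stride
    // or offset bound to a symbol, e.g. (d0)[s0, s1] -> (d0 * s0 + s1) when
    // both are dynamic and (d0) -> (d0 * 7 + 11) when both are static.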
    sourceIndices[index] = rewriter.create<AffineApplyOp>(
        subViewOp.getLoc(), AffineMap::get(1, numSymbols, expr),
        dynamicOperands);
  }
  return success();
}

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.getMemref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getSource();
}

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getSource();
}

/// Given the permutation map of the original
/// `vector.transfer_read`/`vector.transfer_write` operation, compute the
/// permutation map to use after the subview is folded with it.
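///
/// For example (illustrative): if the subview drops dim 1 of a 3-D source,
/// `resultDimToSourceDimMap` below is (d0, d1, d2) -> (d0, d2), and composing
/// an identity permutation map (d0, d1) -> (d0, d1) on the subview with it
/// yields (d0, d1, d2) -> (d0, d2) on the source.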
static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
                                           memref::SubViewOp subViewOp,
                                           AffineMap currPermutationMap) {
  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
  SmallVector<AffineExpr> exprs;
  int64_t sourceRank = subViewOp.getSourceType().getRank();
  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
    if (unusedDims.test(dim))
      continue;
    exprs.push_back(getAffineDimExpr(dim, context));
  }
  auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
  return AffineMapAttr::get(
      currPermutationMap.compose(resultDimToSourceDimMap));
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges a subview operation with a load/transfer_read operation.
template <typename OpTy>
class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;

private:
  void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};

/// Merges a subview operation with a store/transfer_write operation.
template <typename OpTy>
class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;

private:
  void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};

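// The generic overload below covers the affine and memref loads, which share a
// (memref, indices) builder; the vector transfer specializations additionally
// have to recompute the permutation map against the source memref.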
template <typename LoadOpTy>
void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
    LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
    PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.getSource(),
                                        sourceIndices);
}

template <>
void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
    vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  // TODO: support 0-d corner case.
  if (transferReadOp.getTransferRank() == 0)
    return;
  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
      transferReadOp, transferReadOp.getVectorType(), subViewOp.getSource(),
      sourceIndices,
      getPermutationMapAttr(rewriter.getContext(), subViewOp,
                            transferReadOp.getPermutationMap()),
      transferReadOp.getPadding(),
      /*mask=*/Value(), transferReadOp.getInBoundsAttr());
}
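
// For example (illustrative), with an identity permutation map:
//   %1 = memref.subview %0[%arg0, %arg1][4, 4][%stride1, %stride2]
//   %v = vector.transfer_read %1[%i1, %i2], %pad
// becomes a transfer_read of %0 at the affine.apply-resolved source indices,
// with the permutation map recomputed against the rank of %0.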

template <typename StoreOpTy>
void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
    StoreOpTy storeOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.getValue(),
                                         subViewOp.getSource(), sourceIndices);
}

template <>
void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
    vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  // TODO: support 0-d corner case.
  if (transferWriteOp.getTransferRank() == 0)
    return;
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      transferWriteOp, transferWriteOp.getVector(), subViewOp.getSource(),
      sourceIndices,
      getPermutationMapAttr(rewriter.getContext(), subViewOp,
                            transferWriteOp.getPermutationMap()),
      transferWriteOp.getInBoundsAttr());
}
} // namespace

template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
                                             PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
                                  loadOp.getIndices(), sourceIndices)))
    return failure();

  replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
  return success();
}

template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
                                              PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
                                  storeOp.getIndices(), sourceIndices)))
    return failure();

  replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
  return success();
}

void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
               LoadOpOfSubViewFolder<memref::LoadOp>,
               LoadOpOfSubViewFolder<vector::TransferReadOp>,
               StoreOpOfSubViewFolder<AffineStoreOp>,
               StoreOpOfSubViewFolder<memref::StoreOp>,
               StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

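// A minimal invocation sketch; the exact flag comes from the pass definition
// in MemRef/Transforms/Passes.td and is assumed here to be
// fold-memref-subview-ops:
//
//   mlir-opt --fold-memref-subview-ops input.mlir
//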
namespace {

struct FoldSubViewOpsPass final
    : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldSubViewOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldSubViewOpPatterns(patterns);
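  // Besides applying the patterns, the greedy driver erases ops that become
  // trivially dead, so subviews whose accesses have all been folded are
  // cleaned up as part of this call.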
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
  return std::make_unique<FoldSubViewOpsPass>();
}