//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the 'indices' of a load/store operation where the memref is a result
/// of a subview op, returns the indices w.r.t. the source memref of the
/// subview op. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = subview %0[%arg0, %arg1][4, 4][%stride1, %stride2] :
///          memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]>
/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
///
/// could be folded into
///
/// %2 = load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] :
///          memref<12x42xf32>
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
                     memref::SubViewOp subViewOp, ValueRange indices,
                     SmallVectorImpl<Value> &sourceIndices) {
  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();

  SmallVector<Value> useIndices;
  // Check if this is a rank-reducing case. Then for every unit-dim size add a
  // zero to the indices.
  unsigned resultDim = 0;
  llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
  for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
    if (unusedDims.count(dim))
      useIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
    else
      useIndices.push_back(indices[resultDim++]);
  }
  if (useIndices.size() != mixedOffsets.size())
    return failure();
  sourceIndices.resize(useIndices.size());
  for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
    SmallVector<Value> dynamicOperands;
    AffineExpr expr = rewriter.getAffineDimExpr(0);
    unsigned numSymbols = 0;
    dynamicOperands.push_back(useIndices[index]);

    // Multiply by the stride.
    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
      expr = expr * attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedStrides[index].get<Value>());
      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
    }

    // Add the offset.
    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
      expr = expr + attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
    }
    sourceIndices[index] = rewriter.create<AffineApplyOp>(
        loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
  }
  return success();
}

/// Helpers to access the memref operand for each op.
static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }

static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }

static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.source();
}

/// Given the permutation map of the original
/// `vector.transfer_read`/`vector.transfer_write` operations, compute the
/// permutation map to use after the subview is folded with it.
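///
/// For example (an illustrative sketch; the ranks and maps are assumptions
/// for exposition): if the subview drops dimension 0 of a rank-3 source, the
/// retained expressions are [d1, d2], so `resultDimToSourceDimMap` is
/// (d0, d1, d2) -> (d1, d2). Composing an original permutation map
/// (d0, d1) -> (d1) on the subview with it yields (d0, d1, d2) -> (d2) on
/// the source memref.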
static AffineMap getPermutationMap(MLIRContext *context,
                                   memref::SubViewOp subViewOp,
                                   AffineMap currPermutationMap) {
  llvm::SmallDenseSet<unsigned> unusedDims = subViewOp.getDroppedDims();
  SmallVector<AffineExpr> exprs;
  int64_t sourceRank = subViewOp.getSourceType().getRank();
  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
    if (unusedDims.count(dim))
      continue;
    exprs.push_back(getAffineDimExpr(dim, context));
  }
  auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
  return currPermutationMap.compose(resultDimToSourceDimMap);
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges a subview operation with a load/transfer_read operation.
template <typename OpTy>
class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;

private:
  void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};

/// Merges a subview operation with a store/transfer_write operation.
template <typename OpTy>
class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;

private:
  void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};

template <>
void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
    memref::LoadOp loadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(),
                                              sourceIndices);
}

template <>
void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
    vector::TransferReadOp loadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
      loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
      getPermutationMap(rewriter.getContext(), subViewOp,
                        loadOp.permutation_map()),
      loadOp.padding(), loadOp.in_boundsAttr());
}

template <>
void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp(
    memref::StoreOp storeOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<memref::StoreOp>(
      storeOp, storeOp.value(), subViewOp.source(), sourceIndices);
}

template <>
void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
    vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
      sourceIndices,
      getPermutationMap(rewriter.getContext(), subViewOp,
                        transferWriteOp.permutation_map()),
      transferWriteOp.in_boundsAttr());
}
} // namespace

template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
                                             PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
                                  loadOp.indices(), sourceIndices)))
    return failure();

  replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
  return success();
}

template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
                                              PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
                                  storeOp.indices(), sourceIndices)))
    return failure();

  replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
  return success();
}

void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
               LoadOpOfSubViewFolder<vector::TransferReadOp>,
               StoreOpOfSubViewFolder<memref::StoreOp>,
               StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
      patterns.getContext());
}

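// Illustrative end-to-end example of the rank-reducing case handled by the
// patterns above (the shapes, SSA names, and exact printed form below are
// assumptions for exposition, not verified output of this pass):
//
//   %0 = memref.subview %src[4, 2, 8] [1, 4, 4] [1, 1, 1]
//            : memref<8x16x32xf32> to memref<4x4xf32, ...>
//   %1 = memref.load %0[%i, %j] : memref<4x4xf32, ...>
//
// is rewritten to a load from the original memref, with a zero index inserted
// for the dropped unit dimension and an affine.apply resolving each source
// index as offset + index * stride:
//
//   %d0 = affine.apply affine_map<(d0) -> (d0 + 4)>(%c0)
//   %d1 = affine.apply affine_map<(d0) -> (d0 + 2)>(%i)
//   %d2 = affine.apply affine_map<(d0) -> (d0 + 8)>(%j)
//   %1  = memref.load %src[%d0, %d1, %d2] : memref<8x16x32xf32>
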
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"

struct FoldSubViewOpsPass final
    : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldSubViewOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldSubViewOpPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
                                     std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
  return std::make_unique<FoldSubViewOpsPass>();
}
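
// Note: the pass itself is registered through the tablegen'd pass
// infrastructure (Passes.td / Passes.h.inc). A minimal invocation sketch,
// assuming the registered command-line name is "fold-memref-subview-ops":
//
//   mlir-opt --fold-memref-subview-ops input.mlir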