1 //===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation pass folds loading/storing from/to subview ops into 10 // loading/storing from/to the original memref. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/MemRef/Transforms/Passes.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/Dialect/Vector/VectorOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // Utility functions 25 //===----------------------------------------------------------------------===// 26 27 /// Given the 'indices' of an load/store operation where the memref is a result 28 /// of a subview op, returns the indices w.r.t to the source memref of the 29 /// subview op. For example 30 /// 31 /// %0 = ... : memref<12x42xf32> 32 /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to 33 /// memref<4x4xf32, offset=?, strides=[?, ?]> 34 /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> 35 /// 36 /// could be folded into 37 /// 38 /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : 39 /// memref<12x42xf32> 40 static LogicalResult 41 resolveSourceIndices(Location loc, PatternRewriter &rewriter, 42 memref::SubViewOp subViewOp, ValueRange indices, 43 SmallVectorImpl<Value> &sourceIndices) { 44 // TODO: Aborting when the offsets are static. There might be a way to fold 45 // the subview op with load even if the offsets have been canonicalized 46 // away. 47 SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc); 48 if (opRanges.size() != indices.size()) { 49 // For the rank-reduced cases, we can only handle the folding when the 50 // offset is zero, size is 1 and stride is 1. 51 return failure(); 52 } 53 auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; }); 54 auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; }); 55 56 // New indices for the load are the current indices * subview_stride + 57 // subview_offset. 58 sourceIndices.resize(indices.size()); 59 for (auto index : llvm::enumerate(indices)) { 60 auto offset = *(opOffsets.begin() + index.index()); 61 auto stride = *(opStrides.begin() + index.index()); 62 auto mul = rewriter.create<MulIOp>(loc, index.value(), stride); 63 sourceIndices[index.index()] = 64 rewriter.create<AddIOp>(loc, offset, mul).getResult(); 65 } 66 return success(); 67 } 68 69 /// Helpers to access the memref operand for each op. 70 static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); } 71 72 static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); } 73 74 static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); } 75 76 static Value getMemRefOperand(vector::TransferWriteOp op) { 77 return op.source(); 78 } 79 80 //===----------------------------------------------------------------------===// 81 // Patterns 82 //===----------------------------------------------------------------------===// 83 84 namespace { 85 /// Merges subview operation with load/transferRead operation. 86 template <typename OpTy> 87 class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> { 88 public: 89 using OpRewritePattern<OpTy>::OpRewritePattern; 90 91 LogicalResult matchAndRewrite(OpTy loadOp, 92 PatternRewriter &rewriter) const override; 93 94 private: 95 void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp, 96 ArrayRef<Value> sourceIndices, 97 PatternRewriter &rewriter) const; 98 }; 99 100 /// Merges subview operation with store/transferWriteOp operation. 101 template <typename OpTy> 102 class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> { 103 public: 104 using OpRewritePattern<OpTy>::OpRewritePattern; 105 106 LogicalResult matchAndRewrite(OpTy storeOp, 107 PatternRewriter &rewriter) const override; 108 109 private: 110 void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp, 111 ArrayRef<Value> sourceIndices, 112 PatternRewriter &rewriter) const; 113 }; 114 115 template <> 116 void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp( 117 memref::LoadOp loadOp, memref::SubViewOp subViewOp, 118 ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 119 rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(), 120 sourceIndices); 121 } 122 123 template <> 124 void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp( 125 vector::TransferReadOp loadOp, memref::SubViewOp subViewOp, 126 ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 127 rewriter.replaceOpWithNewOp<vector::TransferReadOp>( 128 loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices, 129 loadOp.permutation_map(), loadOp.padding(), loadOp.in_boundsAttr()); 130 } 131 132 template <> 133 void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp( 134 memref::StoreOp storeOp, memref::SubViewOp subViewOp, 135 ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 136 rewriter.replaceOpWithNewOp<memref::StoreOp>( 137 storeOp, storeOp.value(), subViewOp.source(), sourceIndices); 138 } 139 140 template <> 141 void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp( 142 vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp, 143 ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 144 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 145 transferWriteOp, transferWriteOp.vector(), subViewOp.source(), 146 sourceIndices, transferWriteOp.permutation_map(), 147 transferWriteOp.in_boundsAttr()); 148 } 149 } // namespace 150 151 template <typename OpTy> 152 LogicalResult 153 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp, 154 PatternRewriter &rewriter) const { 155 auto subViewOp = 156 getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>(); 157 if (!subViewOp) 158 return failure(); 159 160 SmallVector<Value, 4> sourceIndices; 161 if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, 162 loadOp.indices(), sourceIndices))) 163 return failure(); 164 165 replaceOp(loadOp, subViewOp, sourceIndices, rewriter); 166 return success(); 167 } 168 169 template <typename OpTy> 170 LogicalResult 171 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp, 172 PatternRewriter &rewriter) const { 173 auto subViewOp = 174 getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>(); 175 if (!subViewOp) 176 return failure(); 177 178 SmallVector<Value, 4> sourceIndices; 179 if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, 180 storeOp.indices(), sourceIndices))) 181 return failure(); 182 183 replaceOp(storeOp, subViewOp, sourceIndices, rewriter); 184 return success(); 185 } 186 187 void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) { 188 patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>, 189 LoadOpOfSubViewFolder<vector::TransferReadOp>, 190 StoreOpOfSubViewFolder<memref::StoreOp>, 191 StoreOpOfSubViewFolder<vector::TransferWriteOp>>( 192 patterns.getContext()); 193 } 194 195 //===----------------------------------------------------------------------===// 196 // Pass registration 197 //===----------------------------------------------------------------------===// 198 199 namespace { 200 201 #define GEN_PASS_CLASSES 202 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" 203 204 struct FoldSubViewOpsPass final 205 : public FoldSubViewOpsBase<FoldSubViewOpsPass> { 206 void runOnOperation() override; 207 }; 208 209 } // namespace 210 211 void FoldSubViewOpsPass::runOnOperation() { 212 RewritePatternSet patterns(&getContext()); 213 memref::populateFoldSubViewOpPatterns(patterns); 214 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 215 std::move(patterns)); 216 } 217 218 std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() { 219 return std::make_unique<FoldSubViewOpsPass>(); 220 } 221