1 //===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation pass folds loading/storing from/to subview ops into 10 // loading/storing from/to the original memref. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/MemRef/IR/MemRef.h" 15 #include "mlir/Dialect/MemRef/Transforms/Passes.h" 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/Dialect/Vector/VectorOps.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 20 21 using namespace mlir; 22 23 //===----------------------------------------------------------------------===// 24 // Utility functions 25 //===----------------------------------------------------------------------===// 26 27 /// Given the 'indices' of an load/store operation where the memref is a result 28 /// of a subview op, returns the indices w.r.t to the source memref of the 29 /// subview op. For example 30 /// 31 /// %0 = ... : memref<12x42xf32> 32 /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to 33 /// memref<4x4xf32, offset=?, strides=[?, ?]> 34 /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> 35 /// 36 /// could be folded into 37 /// 38 /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : 39 /// memref<12x42xf32> 40 static LogicalResult 41 resolveSourceIndices(Location loc, PatternRewriter &rewriter, 42 memref::SubViewOp subViewOp, ValueRange indices, 43 SmallVectorImpl<Value> &sourceIndices) { 44 // TODO: Aborting when the offsets are static. There might be a way to fold 45 // the subview op with load even if the offsets have been canonicalized 46 // away. 47 SmallVector<Range, 4> opRanges = subViewOp.getOrCreateRanges(rewriter, loc); 48 auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; }); 49 auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; }); 50 assert(opRanges.size() == indices.size() && 51 "expected as many indices as rank of subview op result type"); 52 53 // New indices for the load are the current indices * subview_stride + 54 // subview_offset. 55 sourceIndices.resize(indices.size()); 56 for (auto index : llvm::enumerate(indices)) { 57 auto offset = *(opOffsets.begin() + index.index()); 58 auto stride = *(opStrides.begin() + index.index()); 59 auto mul = rewriter.create<MulIOp>(loc, index.value(), stride); 60 sourceIndices[index.index()] = 61 rewriter.create<AddIOp>(loc, offset, mul).getResult(); 62 } 63 return success(); 64 } 65 66 /// Helpers to access the memref operand for each op. 67 static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); } 68 69 static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); } 70 71 static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); } 72 73 static Value getMemRefOperand(vector::TransferWriteOp op) { 74 return op.source(); 75 } 76 77 //===----------------------------------------------------------------------===// 78 // Patterns 79 //===----------------------------------------------------------------------===// 80 81 namespace { 82 /// Merges subview operation with load/transferRead operation. 83 template <typename OpTy> 84 class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> { 85 public: 86 using OpRewritePattern<OpTy>::OpRewritePattern; 87 88 LogicalResult matchAndRewrite(OpTy loadOp, 89 PatternRewriter &rewriter) const override; 90 91 private: 92 void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp, 93 ArrayRef<Value> sourceIndices, 94 PatternRewriter &rewriter) const; 95 }; 96 97 /// Merges subview operation with store/transferWriteOp operation. 98 template <typename OpTy> 99 class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> { 100 public: 101 using OpRewritePattern<OpTy>::OpRewritePattern; 102 103 LogicalResult matchAndRewrite(OpTy storeOp, 104 PatternRewriter &rewriter) const override; 105 106 private: 107 void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp, 108 ArrayRef<Value> sourceIndices, 109 PatternRewriter &rewriter) const; 110 }; 111 112 template <> 113 void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp( 114 memref::LoadOp loadOp, memref::SubViewOp subViewOp, 115 ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 116 rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(), 117 sourceIndices); 118 } 119 120 template <> 121 void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp( 122 vector::TransferReadOp loadOp, memref::SubViewOp subViewOp, 123 ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 124 rewriter.replaceOpWithNewOp<vector::TransferReadOp>( 125 loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices, 126 loadOp.permutation_map(), loadOp.padding(), loadOp.in_boundsAttr()); 127 } 128 129 template <> 130 void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp( 131 memref::StoreOp storeOp, memref::SubViewOp subViewOp, 132 ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 133 rewriter.replaceOpWithNewOp<memref::StoreOp>( 134 storeOp, storeOp.value(), subViewOp.source(), sourceIndices); 135 } 136 137 template <> 138 void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp( 139 vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp, 140 ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const { 141 rewriter.replaceOpWithNewOp<vector::TransferWriteOp>( 142 transferWriteOp, transferWriteOp.vector(), subViewOp.source(), 143 sourceIndices, transferWriteOp.permutation_map(), 144 transferWriteOp.in_boundsAttr()); 145 } 146 } // namespace 147 148 template <typename OpTy> 149 LogicalResult 150 LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp, 151 PatternRewriter &rewriter) const { 152 auto subViewOp = 153 getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>(); 154 if (!subViewOp) 155 return failure(); 156 157 SmallVector<Value, 4> sourceIndices; 158 if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, 159 loadOp.indices(), sourceIndices))) 160 return failure(); 161 162 replaceOp(loadOp, subViewOp, sourceIndices, rewriter); 163 return success(); 164 } 165 166 template <typename OpTy> 167 LogicalResult 168 StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp, 169 PatternRewriter &rewriter) const { 170 auto subViewOp = 171 getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>(); 172 if (!subViewOp) 173 return failure(); 174 175 SmallVector<Value, 4> sourceIndices; 176 if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, 177 storeOp.indices(), sourceIndices))) 178 return failure(); 179 180 replaceOp(storeOp, subViewOp, sourceIndices, rewriter); 181 return success(); 182 } 183 184 void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) { 185 patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>, 186 LoadOpOfSubViewFolder<vector::TransferReadOp>, 187 StoreOpOfSubViewFolder<memref::StoreOp>, 188 StoreOpOfSubViewFolder<vector::TransferWriteOp>>( 189 patterns.getContext()); 190 } 191 192 //===----------------------------------------------------------------------===// 193 // Pass registration 194 //===----------------------------------------------------------------------===// 195 196 namespace { 197 198 #define GEN_PASS_CLASSES 199 #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" 200 201 struct FoldSubViewOpsPass final 202 : public FoldSubViewOpsBase<FoldSubViewOpsPass> { 203 void runOnOperation() override; 204 }; 205 206 } // namespace 207 208 void FoldSubViewOpsPass::runOnOperation() { 209 RewritePatternSet patterns(&getContext()); 210 memref::populateFoldSubViewOpPatterns(patterns); 211 (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), 212 std::move(patterns)); 213 } 214 215 std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() { 216 return std::make_unique<FoldSubViewOpsPass>(); 217 } 218