//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the 'indices' of a load/store operation where the memref is a result
/// of a subview op, returns the indices w.r.t. the source memref of the
/// subview op. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
///          memref<4x4xf32, offset=?, strides=[?, ?]>
/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
///
/// could be folded into
///
/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
///          memref<12x42xf32>
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
                     memref::SubViewOp subViewOp, ValueRange indices,
                     SmallVectorImpl<Value> &sourceIndices) {
  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();

  SmallVector<Value> useIndices;
  // Check if this is a rank-reducing case: for every unit-dim size, add a
  // zero to the indices.
  ArrayRef<int64_t> resultShape = subViewOp.getType().getShape();
  unsigned resultDim = 0;
  for (auto size : llvm::enumerate(mixedSizes)) {
    auto attr = size.value().dyn_cast<Attribute>();
    // Check if this dimension has been dropped, i.e. the size is 1, but the
    // associated dimension in the result type is not 1.
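    // For instance (illustrative), a rank-reducing subview
    //   %1 = subview %0[0, 0, 0] [4, 1, 4] [1, 1, 1]
    //            : memref<4x4x4xf32> to memref<4x4xf32, ...>
    // drops the middle dimension, so loads/stores on %1 use a constant zero
    // as the source index for that dimension.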
    if (attr && attr.cast<IntegerAttr>().getInt() == 1 &&
        (resultDim >= resultShape.size() || resultShape[resultDim] != 1))
      useIndices.push_back(rewriter.create<ConstantIndexOp>(loc, 0));
    else if (resultDim < resultShape.size())
      useIndices.push_back(indices[resultDim++]);
  }
  if (useIndices.size() != mixedOffsets.size())
    return failure();
  sourceIndices.resize(useIndices.size());
  for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
    SmallVector<Value> dynamicOperands;
    AffineExpr expr = rewriter.getAffineDimExpr(0);
    unsigned numSymbols = 0;
    dynamicOperands.push_back(useIndices[index]);

    // Multiply the index by the stride. A static stride folds into the affine
    // expression; a dynamic stride becomes a symbol operand.
    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
      expr = expr * attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedStrides[index].get<Value>());
      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
    }

    // Add the offset, again folding static values and using a symbol for
    // dynamic ones.
    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
      expr = expr + attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
    }
    sourceIndices[index] = rewriter.create<AffineApplyOp>(
        loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
  }
  return success();
}

/// Helpers to access the memref operand for each op.
static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); }

static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }

static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.source();
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;

private:
  void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};

/// Merges subview operation with store/transferWrite operation.
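/// Mirrors LoadOpOfSubViewFolder: the source indices are computed by the same
/// resolveSourceIndices helper; only the construction of the replacement op
/// differs.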
template <typename OpTy>
class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;

private:
  void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};

template <>
void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
    memref::LoadOp loadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, subViewOp.source(),
                                              sourceIndices);
}

template <>
void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
    vector::TransferReadOp loadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
      loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices,
      loadOp.permutation_map(), loadOp.padding(), loadOp.in_boundsAttr());
}

template <>
void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp(
    memref::StoreOp storeOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<memref::StoreOp>(
      storeOp, storeOp.value(), subViewOp.source(), sourceIndices);
}

template <>
void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
    vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
      sourceIndices, transferWriteOp.permutation_map(),
      transferWriteOp.in_boundsAttr());
}
} // namespace

template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
                                             PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
                                  loadOp.indices(), sourceIndices)))
    return failure();

  replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
  return success();
}

template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
                                              PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
                                  storeOp.indices(), sourceIndices)))
    return failure();

  replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
  return success();
}

void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
               LoadOpOfSubViewFolder<vector::TransferReadOp>,
               StoreOpOfSubViewFolder<memref::StoreOp>,
               StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
      patterns.getContext());
}
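// Note: these patterns are self-contained and may also be added to any other
// RewritePatternSet; the pass below is a thin convenience wrapper that applies
// just these patterns greedily.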
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"

struct FoldSubViewOpsPass final
    : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldSubViewOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldSubViewOpPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
                                     std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
  return std::make_unique<FoldSubViewOpsPass>();
}