//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the `indices` of a load/store operation where the memref is a result
/// of a subview op, returns the indices w.r.t. the source memref of the
/// subview op. For example:
///
///   %0 = ... : memref<12x42xf32>
///   %1 = subview %0[%arg0, %arg1][4, 4][%stride1, %stride2] :
///          memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]>
///   %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
///
/// could be folded into:
///
///   %2 = load %0[%arg0 + %i1 * %stride1, %arg1 + %i2 * %stride2] :
///          memref<12x42xf32>
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
                     memref::SubViewOp subViewOp, ValueRange indices,
                     SmallVectorImpl<Value> &sourceIndices) {
  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();

  SmallVector<Value> useIndices;
  // Check if this is a rank-reducing case; if so, use a zero index for every
  // dropped unit dimension.
  unsigned resultDim = 0;
  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
  for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
    if (unusedDims.test(dim))
      useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
    else
      useIndices.push_back(indices[resultDim++]);
  }
  if (useIndices.size() != mixedOffsets.size())
    return failure();
  sourceIndices.resize(useIndices.size());
  for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
    SmallVector<Value> dynamicOperands;
    AffineExpr expr = rewriter.getAffineDimExpr(0);
    unsigned numSymbols = 0;
    dynamicOperands.push_back(useIndices[index]);

    // Multiply by the stride.
    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
      expr = expr * attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedStrides[index].get<Value>());
      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
    }

    // Add the offset.
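    // With a dynamic stride and a dynamic offset, the expression built here
    // is affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>, with `dynamicOperands`
    // holding the index, stride, and offset values in that order.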
    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
      expr = expr + attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
    }
    sourceIndices[index] = rewriter.create<AffineApplyOp>(
        loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
  }
  return success();
}

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.memref();
}

static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); }

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.source();
}

/// Given the permutation map of the original
/// `vector.transfer_read`/`vector.transfer_write` operation, compute the
/// permutation map to use after the subview is folded into it. E.g., if the
/// subview drops the middle dimension of a 3-d source and the original map is
/// the identity (d0, d1) -> (d0, d1), the composed map is
/// (d0, d1, d2) -> (d0, d2).
static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
                                           memref::SubViewOp subViewOp,
                                           AffineMap currPermutationMap) {
  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
  SmallVector<AffineExpr> exprs;
  int64_t sourceRank = subViewOp.getSourceType().getRank();
  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
    if (unusedDims.test(dim))
      continue;
    exprs.push_back(getAffineDimExpr(dim, context));
  }
  auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
  return AffineMapAttr::get(
      currPermutationMap.compose(resultDimToSourceDimMap));
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges subview operation with load/transferRead operation.
template <typename OpTy>
class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;

private:
  void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};

/// Merges subview operation with store/transferWrite operation.
template <typename OpTy>
class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;

private:
  void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};

template <typename LoadOpTy>
void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
    LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
    PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.source(),
                                        sourceIndices);
}

template <>
void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
    vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  // TODO: support 0-d corner case.
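  // A 0-d transfer has no indices to remap, so leave the op untouched.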
  if (transferReadOp.getTransferRank() == 0)
    return;
  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
      transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
      sourceIndices,
      getPermutationMapAttr(rewriter.getContext(), subViewOp,
                            transferReadOp.permutation_map()),
      transferReadOp.padding(),
      /*mask=*/Value(), transferReadOp.in_boundsAttr());
}

template <typename StoreOpTy>
void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
    StoreOpTy storeOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.value(),
                                         subViewOp.source(), sourceIndices);
}

template <>
void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
    vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  // TODO: support 0-d corner case.
  if (transferWriteOp.getTransferRank() == 0)
    return;
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      transferWriteOp, transferWriteOp.vector(), subViewOp.source(),
      sourceIndices,
      getPermutationMapAttr(rewriter.getContext(), subViewOp,
                            transferWriteOp.permutation_map()),
      transferWriteOp.in_boundsAttr());
}
} // namespace

template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
                                             PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
                                  loadOp.indices(), sourceIndices)))
    return failure();

  replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
  return success();
}

template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
                                              PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
                                  storeOp.indices(), sourceIndices)))
    return failure();

  replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
  return success();
}

void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
               LoadOpOfSubViewFolder<memref::LoadOp>,
               LoadOpOfSubViewFolder<vector::TransferReadOp>,
               StoreOpOfSubViewFolder<AffineStoreOp>,
               StoreOpOfSubViewFolder<memref::StoreOp>,
               StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"

struct FoldSubViewOpsPass final
    : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldSubViewOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldSubViewOpPatterns(patterns);
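  // Apply the patterns greedily; the fold is best-effort, so the convergence
  // status is intentionally discarded.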
  (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
                                     std::move(patterns));
}

std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
  return std::make_unique<FoldSubViewOpsPass>();
}
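// Note: the pass class above is generated from the tablegen'd pass
// definitions included via Passes.h.inc; the command-line flag that runs this
// pass is whatever that definition declares (at the time of writing,
// something like `mlir-opt -fold-memref-subview-ops`).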