//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the `indices` of a load/store operation where the memref is a result
/// of a subview op, returns the indices w.r.t. the source memref of the
/// subview op. For example
///
///   %0 = ... : memref<12x42xf32>
///   %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] :
///          memref<12x42xf32> to memref<4x4xf32, offset=?, strides=[?, ?]>
///   %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
///
/// could be folded into
///
///   %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
///          memref<12x42xf32>
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
                     memref::SubViewOp subViewOp, ValueRange indices,
                     SmallVectorImpl<Value> &sourceIndices) {
  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();

  SmallVector<Value> useIndices;
  // Check if this is a rank-reducing case. Then for every unit-dim size add a
  // zero to the indices.
  unsigned resultDim = 0;
  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
  for (auto dim : llvm::seq<unsigned>(0, subViewOp.getSourceType().getRank())) {
    if (unusedDims.test(dim))
      useIndices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
    else
      useIndices.push_back(indices[resultDim++]);
  }
  if (useIndices.size() != mixedOffsets.size())
    return failure();
  sourceIndices.resize(useIndices.size());
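  // For each dimension, fold the subview into the index computation by
  // building the affine expression `useIndex * stride + offset` and
  // materializing it as an affine.apply. When both the stride and the offset
  // are dynamic, the resulting map is
  //   affine_map<(d0)[s0, s1] -> (d0 * s0 + s1)>
  // applied to (useIndices[index], stride, offset).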
  for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
    SmallVector<Value> dynamicOperands;
    AffineExpr expr = rewriter.getAffineDimExpr(0);
    unsigned numSymbols = 0;
    dynamicOperands.push_back(useIndices[index]);

    // Multiply by the stride.
    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
      expr = expr * attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedStrides[index].get<Value>());
      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
    }

    // Add the offset.
    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
      expr = expr + attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
    }
    sourceIndices[index] = rewriter.create<AffineApplyOp>(
        loc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
  }
  return success();
}

/// Helpers to access the memref operand for each op.
template <typename LoadOrStoreOpTy>
static Value getMemRefOperand(LoadOrStoreOpTy op) {
  return op.memref();
}

static Value getMemRefOperand(vector::TransferReadOp op) {
  return op.getSource();
}

static Value getMemRefOperand(vector::TransferWriteOp op) {
  return op.getSource();
}

/// Given the permutation map of the original
/// `vector.transfer_read`/`vector.transfer_write` operation, compute the
/// permutation map to use after the subview is folded into it. For example,
/// if the source memref has rank 3 and the subview drops dimension 1, the
/// result dims map to the source dims as (d0, d1, d2) -> (d0, d2); composing
/// a transposed permutation map (d0, d1) -> (d1, d0) with it yields
/// (d0, d1, d2) -> (d2, d0).
static AffineMapAttr getPermutationMapAttr(MLIRContext *context,
                                           memref::SubViewOp subViewOp,
                                           AffineMap currPermutationMap) {
  llvm::SmallBitVector unusedDims = subViewOp.getDroppedDims();
  SmallVector<AffineExpr> exprs;
  int64_t sourceRank = subViewOp.getSourceType().getRank();
  for (auto dim : llvm::seq<int64_t>(0, sourceRank)) {
    if (unusedDims.test(dim))
      continue;
    exprs.push_back(getAffineDimExpr(dim, context));
  }
  auto resultDimToSourceDimMap = AffineMap::get(sourceRank, 0, exprs, context);
  return AffineMapAttr::get(
      currPermutationMap.compose(resultDimToSourceDimMap));
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Merges a subview operation with a load/transfer_read operation.
template <typename OpTy>
class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy loadOp,
                                PatternRewriter &rewriter) const override;

private:
  void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};

/// Merges a subview operation with a store/transfer_write operation.
template <typename OpTy>
class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy storeOp,
                                PatternRewriter &rewriter) const override;

private:
  void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp,
                 ArrayRef<Value> sourceIndices,
                 PatternRewriter &rewriter) const;
};

template <typename LoadOpTy>
void LoadOpOfSubViewFolder<LoadOpTy>::replaceOp(
    LoadOpTy loadOp, memref::SubViewOp subViewOp, ArrayRef<Value> sourceIndices,
    PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<LoadOpTy>(loadOp, subViewOp.source(),
                                        sourceIndices);
}

template <>
void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
    vector::TransferReadOp transferReadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  // TODO: support 0-d corner case.
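  // The permutation map composition below requires a non-zero transfer rank;
  // leave 0-d transfers untouched for now.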
  if (transferReadOp.getTransferRank() == 0)
    return;
  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
      transferReadOp, transferReadOp.getVectorType(), subViewOp.source(),
      sourceIndices,
      getPermutationMapAttr(rewriter.getContext(), subViewOp,
                            transferReadOp.getPermutationMap()),
      transferReadOp.getPadding(),
      /*mask=*/Value(), transferReadOp.getInBoundsAttr());
}

template <typename StoreOpTy>
void StoreOpOfSubViewFolder<StoreOpTy>::replaceOp(
    StoreOpTy storeOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<StoreOpTy>(storeOp, storeOp.value(),
                                         subViewOp.source(), sourceIndices);
}

template <>
void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
    vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  // TODO: support 0-d corner case.
  if (transferWriteOp.getTransferRank() == 0)
    return;
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      transferWriteOp, transferWriteOp.getVector(), subViewOp.source(),
      sourceIndices,
      getPermutationMapAttr(rewriter.getContext(), subViewOp,
                            transferWriteOp.getPermutationMap()),
      transferWriteOp.getInBoundsAttr());
}
} // namespace

template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
                                             PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(loadOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp,
                                  loadOp.getIndices(), sourceIndices)))
    return failure();

  replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
  return success();
}

template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
                                              PatternRewriter &rewriter) const {
  auto subViewOp =
      getMemRefOperand(storeOp).template getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  SmallVector<Value, 4> sourceIndices;
  if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp,
                                  storeOp.getIndices(), sourceIndices)))
    return failure();

  replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
  return success();
}

void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
  patterns.add<LoadOpOfSubViewFolder<AffineLoadOp>,
               LoadOpOfSubViewFolder<memref::LoadOp>,
               LoadOpOfSubViewFolder<vector::TransferReadOp>,
               StoreOpOfSubViewFolder<AffineStoreOp>,
               StoreOpOfSubViewFolder<memref::StoreOp>,
               StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
      patterns.getContext());
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

struct FoldSubViewOpsPass final
    : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
  void runOnOperation() override;
};

} // namespace

void FoldSubViewOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldSubViewOpPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}

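/// Creates a pass that folds memref.subview ops into their load/store and
/// vector transfer users.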
std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
  return std::make_unique<FoldSubViewOpsPass>();
}
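// Note: clients that want this folding inside a larger pass, rather than as a
// standalone pass, can populate the patterns directly. A minimal sketch that
// mirrors runOnOperation() above (`op` and `ctx` stand in for the caller's
// operation and context):
//
//   RewritePatternSet patterns(ctx);
//   memref::populateFoldSubViewOpPatterns(patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));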