//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This transformation pass folds loading/storing from/to subview ops into // loading/storing from/to the original memref. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; //===----------------------------------------------------------------------===// // Utility functions //===----------------------------------------------------------------------===// /// Given the 'indices' of an load/store operation where the memref is a result /// of a subview op, returns the indices w.r.t to the source memref of the /// subview op. For example /// /// %0 = ... : memref<12x42xf32> /// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to /// memref<4x4xf32, offset=?, strides=[?, ?]> /// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]> /// /// could be folded into /// /// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] : /// memref<12x42xf32> static LogicalResult resolveSourceIndices(Location loc, PatternRewriter &rewriter, memref::SubViewOp subViewOp, ValueRange indices, SmallVectorImpl &sourceIndices) { SmallVector mixedOffsets = subViewOp.getMixedOffsets(); SmallVector mixedSizes = subViewOp.getMixedSizes(); SmallVector mixedStrides = subViewOp.getMixedStrides(); SmallVector useIndices; // Check if this is rank-reducing case. Then for every unit-dim size add a // zero to the indices. ArrayRef resultShape = subViewOp.getType().getShape(); unsigned resultDim = 0; for (auto size : llvm::enumerate(mixedSizes)) { auto attr = size.value().dyn_cast(); // Check if this dimension has been dropped, i.e. the size is 1, but the // associated dimension is not 1. if (attr && attr.cast().getInt() == 1 && (resultDim >= resultShape.size() || resultShape[resultDim] != 1)) useIndices.push_back(rewriter.create(loc, 0)); else if (resultDim < resultShape.size()) { useIndices.push_back(indices[resultDim++]); } } if (useIndices.size() != mixedOffsets.size()) return failure(); sourceIndices.resize(useIndices.size()); for (auto index : llvm::seq(0, mixedOffsets.size())) { SmallVector dynamicOperands; AffineExpr expr = rewriter.getAffineDimExpr(0); unsigned numSymbols = 0; dynamicOperands.push_back(useIndices[index]); // Multiply the stride; if (auto attr = mixedStrides[index].dyn_cast()) { expr = expr * attr.cast().getInt(); } else { dynamicOperands.push_back(mixedStrides[index].get()); expr = expr * rewriter.getAffineSymbolExpr(numSymbols++); } // Add the offset. if (auto attr = mixedOffsets[index].dyn_cast()) { expr = expr + attr.cast().getInt(); } else { dynamicOperands.push_back(mixedOffsets[index].get()); expr = expr + rewriter.getAffineSymbolExpr(numSymbols++); } Location loc = subViewOp.getLoc(); sourceIndices[index] = rewriter.create( loc, AffineMap::get(1, numSymbols, expr), dynamicOperands); } return success(); } /// Helpers to access the memref operand for each op. static Value getMemRefOperand(memref::LoadOp op) { return op.memref(); } static Value getMemRefOperand(vector::TransferReadOp op) { return op.source(); } static Value getMemRefOperand(memref::StoreOp op) { return op.memref(); } static Value getMemRefOperand(vector::TransferWriteOp op) { return op.source(); } //===----------------------------------------------------------------------===// // Patterns //===----------------------------------------------------------------------===// namespace { /// Merges subview operation with load/transferRead operation. template class LoadOpOfSubViewFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy loadOp, PatternRewriter &rewriter) const override; private: void replaceOp(OpTy loadOp, memref::SubViewOp subViewOp, ArrayRef sourceIndices, PatternRewriter &rewriter) const; }; /// Merges subview operation with store/transferWriteOp operation. template class StoreOpOfSubViewFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpTy storeOp, PatternRewriter &rewriter) const override; private: void replaceOp(OpTy storeOp, memref::SubViewOp subViewOp, ArrayRef sourceIndices, PatternRewriter &rewriter) const; }; template <> void LoadOpOfSubViewFolder::replaceOp( memref::LoadOp loadOp, memref::SubViewOp subViewOp, ArrayRef sourceIndices, PatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp(loadOp, subViewOp.source(), sourceIndices); } template <> void LoadOpOfSubViewFolder::replaceOp( vector::TransferReadOp loadOp, memref::SubViewOp subViewOp, ArrayRef sourceIndices, PatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp( loadOp, loadOp.getVectorType(), subViewOp.source(), sourceIndices, loadOp.permutation_map(), loadOp.padding(), loadOp.in_boundsAttr()); } template <> void StoreOpOfSubViewFolder::replaceOp( memref::StoreOp storeOp, memref::SubViewOp subViewOp, ArrayRef sourceIndices, PatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp( storeOp, storeOp.value(), subViewOp.source(), sourceIndices); } template <> void StoreOpOfSubViewFolder::replaceOp( vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp, ArrayRef sourceIndices, PatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp( transferWriteOp, transferWriteOp.vector(), subViewOp.source(), sourceIndices, transferWriteOp.permutation_map(), transferWriteOp.in_boundsAttr()); } } // namespace template LogicalResult LoadOpOfSubViewFolder::matchAndRewrite(OpTy loadOp, PatternRewriter &rewriter) const { auto subViewOp = getMemRefOperand(loadOp).template getDefiningOp(); if (!subViewOp) return failure(); SmallVector sourceIndices; if (failed(resolveSourceIndices(loadOp.getLoc(), rewriter, subViewOp, loadOp.indices(), sourceIndices))) return failure(); replaceOp(loadOp, subViewOp, sourceIndices, rewriter); return success(); } template LogicalResult StoreOpOfSubViewFolder::matchAndRewrite(OpTy storeOp, PatternRewriter &rewriter) const { auto subViewOp = getMemRefOperand(storeOp).template getDefiningOp(); if (!subViewOp) return failure(); SmallVector sourceIndices; if (failed(resolveSourceIndices(storeOp.getLoc(), rewriter, subViewOp, storeOp.indices(), sourceIndices))) return failure(); replaceOp(storeOp, subViewOp, sourceIndices, rewriter); return success(); } void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) { patterns.add, LoadOpOfSubViewFolder, StoreOpOfSubViewFolder, StoreOpOfSubViewFolder>( patterns.getContext()); } //===----------------------------------------------------------------------===// // Pass registration //===----------------------------------------------------------------------===// namespace { #define GEN_PASS_CLASSES #include "mlir/Dialect/MemRef/Transforms/Passes.h.inc" struct FoldSubViewOpsPass final : public FoldSubViewOpsBase { void runOnOperation() override; }; } // namespace void FoldSubViewOpsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); memref::populateFoldSubViewOpPatterns(patterns); (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(), std::move(patterns)); } std::unique_ptr memref::createFoldSubViewOpsPass() { return std::make_unique(); }