//===- FoldSubViewOps.cpp - Fold memref.subview ops -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass folds loading/storing from/to subview ops into
// loading/storing from/to the original memref.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Given the `indices` of a load/store operation whose memref operand is the
/// result of a subview op, computes the equivalent indices w.r.t. the source
/// memref of the subview op. For example
///
/// %0 = ... : memref<12x42xf32>
/// %1 = subview %0[%arg0, %arg1][][%stride1, %stride2] : memref<12x42xf32> to
///          memref<4x4xf32, offset=?, strides=[?, ?]>
/// %2 = load %1[%i1, %i2] : memref<4x4xf32, offset=?, strides=[?, ?]>
///
/// could be folded into
///
/// %2 = load %0[%arg0 + %i1 * %stride1][%arg1 + %i2 * %stride2] :
///          memref<12x42xf32>
///
/// Rank-reducing subviews are handled by using a constant zero index for every
/// source dimension that is classified as dropped (static size 1 whose
/// matching result dimension is missing or not 1).
///
/// \param loc           Location used for the zero-index constants.
/// \param rewriter      Rewriter used to create the new index computations.
/// \param subViewOp     The subview op producing the accessed memref.
/// \param indices       Indices of the access into the subview result.
/// \param sourceIndices Output. On success it is resized to the number of
///                      subview offsets and holds one `affine.apply` result
///                      per source dimension. Left untouched on failure.
/// \returns failure() if the access indices cannot be matched one-to-one with
/// the subview's offsets; in that case no IR is created. success() otherwise.
static LogicalResult
resolveSourceIndices(Location loc, PatternRewriter &rewriter,
                     memref::SubViewOp subViewOp, ValueRange indices,
                     SmallVectorImpl<Value> &sourceIndices) {
  SmallVector<OpFoldResult> mixedOffsets = subViewOp.getMixedOffsets();
  SmallVector<OpFoldResult> mixedSizes = subViewOp.getMixedSizes();
  SmallVector<OpFoldResult> mixedStrides = subViewOp.getMixedStrides();

  // Phase 1 (no IR mutation): decide, for every source dimension, whether it
  // is indexed by one of the incoming `indices` or was dropped by a
  // rank-reducing subview. Dropped dimensions are recorded as a null Value
  // placeholder so that nothing is created before the consistency check below;
  // a pattern that reports failure must leave the IR unchanged.
  ArrayRef<int64_t> resultShape = subViewOp.getType().getShape();
  SmallVector<Value> useIndices;
  unsigned resultDim = 0;
  for (OpFoldResult size : mixedSizes) {
    auto attr = size.dyn_cast<Attribute>();
    bool isStaticUnitSize = attr && attr.cast<IntegerAttr>().getInt() == 1;
    bool hasResultDim = resultDim < resultShape.size();
    // A dimension has been dropped if its size is 1, but the associated
    // result dimension is either absent or not 1.
    if (isStaticUnitSize && (!hasResultDim || resultShape[resultDim] != 1)) {
      useIndices.push_back(Value());
    } else if (hasResultDim) {
      useIndices.push_back(indices[resultDim++]);
    }
  }
  if (useIndices.size() != mixedOffsets.size())
    return failure();

  // Phase 2: the match is known to succeed, so it is now safe to materialize
  // the zero indices for the dropped dimensions.
  for (Value &useIndex : useIndices) {
    if (useIndex)
      continue;
    Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
    useIndex = zero;
  }

  // The index computations are attached to the subview's location (not to
  // `loc`); this matches the behavior of the previous implementation, which
  // shadowed `loc` inside the loop.
  Location subViewLoc = subViewOp.getLoc();
  sourceIndices.resize(useIndices.size());
  for (auto index : llvm::seq<size_t>(0, mixedOffsets.size())) {
    // Build `d0 * stride + offset`, where d0 is the access index and
    // dynamic strides/offsets are passed as symbols.
    SmallVector<Value> dynamicOperands;
    AffineExpr expr = rewriter.getAffineDimExpr(0);
    unsigned numSymbols = 0;
    dynamicOperands.push_back(useIndices[index]);

    // Multiply by the stride.
    if (auto attr = mixedStrides[index].dyn_cast<Attribute>()) {
      expr = expr * attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedStrides[index].get<Value>());
      expr = expr * rewriter.getAffineSymbolExpr(numSymbols++);
    }

    // Add the offset.
    if (auto attr = mixedOffsets[index].dyn_cast<Attribute>()) {
      expr = expr + attr.cast<IntegerAttr>().getInt();
    } else {
      dynamicOperands.push_back(mixedOffsets[index].get<Value>());
      expr = expr + rewriter.getAffineSymbolExpr(numSymbols++);
    }
    sourceIndices[index] = rewriter.create<AffineApplyOp>(
        subViewLoc, AffineMap::get(1, numSymbols, expr), dynamicOperands);
  }
  return success();
}

/// Overload set giving uniform access to the value that a load/store-like op
/// reads from or writes to, so the templated patterns below can treat all
/// supported ops alike.

/// memref.load: returns the op's `memref` operand.
static Value getMemRefOperand(memref::LoadOp op) {
  Value accessed = op.memref();
  return accessed;
}

/// vector.transfer_read: returns the op's `source` operand.
static Value getMemRefOperand(vector::TransferReadOp op) {
  Value accessed = op.source();
  return accessed;
}

/// memref.store: returns the op's `memref` operand.
static Value getMemRefOperand(memref::StoreOp op) {
  Value accessed = op.memref();
  return accessed;
}

/// vector.transfer_write: returns the op's `source` operand.
static Value getMemRefOperand(vector::TransferWriteOp op) {
  Value accessed = op.source();
  return accessed;
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

namespace {
/// Rewrite pattern that folds a memref.subview into a load-like op `OpTy`
/// (memref.load or vector.transfer_read) that reads from the subview result,
/// so that the op reads from the subview's source memref instead.
template <typename OpTy>
class LoadOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  // Reuse the base pattern's constructors.
  using OpRewritePattern<OpTy>::OpRewritePattern;

  /// Matches `op` when its accessed memref is produced by a memref.subview
  /// and rewrites it; the definition lives outside the class.
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override;

private:
  /// Op-specific replacement step; explicitly specialized per supported
  /// `OpTy`. `indicesInSource` are the access indices expressed in terms of
  /// the source memref of `producer`.
  void replaceOp(OpTy op, memref::SubViewOp producer,
                 ArrayRef<Value> indicesInSource,
                 PatternRewriter &rewriter) const;
};

/// Rewrite pattern that folds a memref.subview into a store-like op `OpTy`
/// (memref.store or vector.transfer_write) that writes to the subview result,
/// so that the op writes to the subview's source memref instead.
template <typename OpTy>
class StoreOpOfSubViewFolder final : public OpRewritePattern<OpTy> {
public:
  // Reuse the base pattern's constructors.
  using OpRewritePattern<OpTy>::OpRewritePattern;

  /// Matches `op` when its accessed memref is produced by a memref.subview
  /// and rewrites it; the definition lives outside the class.
  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override;

private:
  /// Op-specific replacement step; explicitly specialized per supported
  /// `OpTy`. `indicesInSource` are the access indices expressed in terms of
  /// the source memref of `producer`.
  void replaceOp(OpTy op, memref::SubViewOp producer,
                 ArrayRef<Value> indicesInSource,
                 PatternRewriter &rewriter) const;
};

/// memref.load specialization: replaces `loadOp` with a new memref.load that
/// reads the subview's source memref at `sourceIndices`.
template <>
void LoadOpOfSubViewFolder<memref::LoadOp>::replaceOp(
    memref::LoadOp loadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  Value foldedSource = subViewOp.source();
  rewriter.replaceOpWithNewOp<memref::LoadOp>(loadOp, foldedSource,
                                              sourceIndices);
}

/// vector.transfer_read specialization: replaces `loadOp` with a new
/// vector.transfer_read on the subview's source memref at `sourceIndices`,
/// carrying over the vector type, permutation map, padding and in_bounds
/// attribute of the original op.
///
/// NOTE(review): the permutation map is forwarded unchanged. If the subview
/// is rank-reducing, the map may not match the rank of the source memref --
/// confirm whether that case needs dedicated handling.
template <>
void LoadOpOfSubViewFolder<vector::TransferReadOp>::replaceOp(
    vector::TransferReadOp loadOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  Value foldedSource = subViewOp.source();
  auto permutationMap = loadOp.permutation_map();
  auto inBounds = loadOp.in_boundsAttr();
  rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
      loadOp, loadOp.getVectorType(), foldedSource, sourceIndices,
      permutationMap, loadOp.padding(), inBounds);
}

/// memref.store specialization: replaces `storeOp` with a new memref.store of
/// the same value into the subview's source memref at `sourceIndices`.
template <>
void StoreOpOfSubViewFolder<memref::StoreOp>::replaceOp(
    memref::StoreOp storeOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  Value storedValue = storeOp.value();
  Value foldedTarget = subViewOp.source();
  rewriter.replaceOpWithNewOp<memref::StoreOp>(storeOp, storedValue,
                                               foldedTarget, sourceIndices);
}

/// vector.transfer_write specialization: replaces `transferWriteOp` with a
/// new vector.transfer_write of the same vector into the subview's source
/// memref at `sourceIndices`, carrying over the permutation map and in_bounds
/// attribute of the original op.
///
/// NOTE(review): the permutation map is forwarded unchanged. If the subview
/// is rank-reducing, the map may not match the rank of the source memref --
/// confirm whether that case needs dedicated handling.
template <>
void StoreOpOfSubViewFolder<vector::TransferWriteOp>::replaceOp(
    vector::TransferWriteOp transferWriteOp, memref::SubViewOp subViewOp,
    ArrayRef<Value> sourceIndices, PatternRewriter &rewriter) const {
  Value foldedTarget = subViewOp.source();
  auto permutationMap = transferWriteOp.permutation_map();
  auto inBounds = transferWriteOp.in_boundsAttr();
  rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
      transferWriteOp, transferWriteOp.vector(), foldedTarget, sourceIndices,
      permutationMap, inBounds);
}
} // namespace

/// Folds the subview producing the memref read by `loadOp` into the op.
/// Fails (without replacing `loadOp`) when the accessed memref is not defined
/// by a memref.subview or when the source indices cannot be resolved.
template <typename OpTy>
LogicalResult
LoadOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy loadOp,
                                             PatternRewriter &rewriter) const {
  // Only accesses whose memref comes straight from a subview are handled.
  Value accessedMemRef = getMemRefOperand(loadOp);
  auto subViewOp = accessedMemRef.getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  // Translate the access indices into the coordinate space of the source.
  SmallVector<Value, 4> sourceIndices;
  LogicalResult resolved = resolveSourceIndices(
      loadOp.getLoc(), rewriter, subViewOp, loadOp.indices(), sourceIndices);
  if (failed(resolved))
    return failure();

  this->replaceOp(loadOp, subViewOp, sourceIndices, rewriter);
  return success();
}

/// Folds the subview producing the memref written by `storeOp` into the op.
/// Fails (without replacing `storeOp`) when the accessed memref is not
/// defined by a memref.subview or when the source indices cannot be resolved.
template <typename OpTy>
LogicalResult
StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
                                              PatternRewriter &rewriter) const {
  // Only accesses whose memref comes straight from a subview are handled.
  Value accessedMemRef = getMemRefOperand(storeOp);
  auto subViewOp = accessedMemRef.getDefiningOp<memref::SubViewOp>();
  if (!subViewOp)
    return failure();

  // Translate the access indices into the coordinate space of the source.
  SmallVector<Value, 4> sourceIndices;
  LogicalResult resolved = resolveSourceIndices(
      storeOp.getLoc(), rewriter, subViewOp, storeOp.indices(), sourceIndices);
  if (failed(resolved))
    return failure();

  this->replaceOp(storeOp, subViewOp, sourceIndices, rewriter);
  return success();
}

/// Adds the subview-folding patterns for memref.load, vector.transfer_read,
/// memref.store and vector.transfer_write to `patterns`.
void memref::populateFoldSubViewOpPatterns(RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();

  // Load-like folders are registered first, followed by the store-like ones.
  patterns.add<LoadOpOfSubViewFolder<memref::LoadOp>,
               LoadOpOfSubViewFolder<vector::TransferReadOp>>(context);
  patterns.add<StoreOpOfSubViewFolder<memref::StoreOp>,
               StoreOpOfSubViewFolder<vector::TransferWriteOp>>(context);
}

//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//

namespace {

// The macro is defined immediately before the generated include.
// NOTE(review): presumably it selects the generated pass base classes
// (including FoldSubViewOpsBase used below) and must precede the include --
// confirm against the MemRef Passes.td / Passes.h.inc.
#define GEN_PASS_CLASSES
#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"

/// Pass that folds memref.subview ops into the load/store-like ops that access
/// their results, using the patterns registered by
/// memref::populateFoldSubViewOpPatterns.
struct FoldSubViewOpsPass final
    : public FoldSubViewOpsBase<FoldSubViewOpsPass> {
  /// Pass entry point; defined out of line after this namespace.
  void runOnOperation() override;
};

} // namespace

void FoldSubViewOpsPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());
  memref::populateFoldSubViewOpPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
                                     std::move(patterns));
}

/// Factory returning a new instance of the subview-folding pass.
std::unique_ptr<Pass> memref::createFoldSubViewOpsPass() {
  std::unique_ptr<Pass> pass = std::make_unique<FoldSubViewOpsPass>();
  return pass;
}
