//===- TilingInterfaceImpl.cpp - Implementation of TilingInterface -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Arithmetic/Utils/Utils.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/TilingInterface.h"

using namespace mlir;
using namespace mlir::linalg;

namespace {

/// External model implementation of TilingInterface for LinalgOps. An external
/// model is used for now, until the use of `TilingInterface` is on par with
/// the current Linalg tiling + fusion patterns. Once it is on par, it may be
/// possible to move this into the op definition (though there are advantages
/// to leaving it as an external model).
template <typename LinalgOpTy>
struct LinalgOpTilingInterface
    : public TilingInterface::ExternalModel<LinalgOpTilingInterface<LinalgOpTy>,
                                            LinalgOpTy> {
  /// Returns the output operands of `op` as the destination operands. The
  /// builder `b` is not used.
  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
    LinalgOp linalgOp = cast<LinalgOp>(op);
    return linalgOp.getOutputOperands();
  }

  /// Returns the string value of every entry of the op's `iterator_types`, in
  /// order.
  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
    auto iteratorAttrs = cast<LinalgOpTy>(op).iterator_types();
    SmallVector<StringRef> iteratorNames;
    for (Attribute attr : iteratorAttrs)
      iteratorNames.push_back(attr.cast<StringAttr>().getValue());
    return iteratorNames;
  }

  /// Returns one `Range` per result of the shapes-to-loops map, each of the
  /// form {0, bound, 1}. All IR needed to compute the bounds is created
  /// immediately before `op`; the caller's insertion point is restored on
  /// return.
  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPoint(op);

    auto linalgOp = cast<LinalgOp>(op);
    Location loc = op->getLoc();
    auto operandDims = linalgOp.createFlatListOfOperandDims(b, loc);
    AffineMap shapesToLoops = linalgOp.getShapesToLoopsMap();
    Value lowerBound = b.create<arith::ConstantIndexOp>(loc, 0);
    Value step = b.create<arith::ConstantIndexOp>(loc, 1);

    SmallVector<Range> domain;
    for (Value upperBound :
         applyMapToValues(b, loc, shapesToLoops, operandDims))
      domain.push_back(Range{lowerBound, upperBound, step});
    return domain;
  }

  /// Instantiates the tiled implementation of the operation: tiles all input
  /// and output operands with the given `offsets` and `sizes`, clones the op
  /// onto the tiled operands and offsets its indices by `offsets`. Returns
  /// the single cloned op. `dest` and `tileDestOperands` are not consulted by
  /// this implementation.
  SmallVector<Operation *>
  getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
                         bool tileDestOperands) const {
    auto linalgOp = cast<LinalgOp>(op);
    Location loc = op->getLoc();

    SmallVector<Value> operandsToTile = linalgOp.getInputAndOutputOperands();
    SmallVector<Value> tileOffsets =
        getValueOrCreateConstantIndexOp(b, loc, offsets);
    SmallVector<Value> tileSizes =
        getValueOrCreateConstantIndexOp(b, loc, sizes);
    // The `sizeBounds` argument is left empty: it is only needed when the
    // `sizes` specified could lead to out of bounds accesses.
    SmallVector<Value, 4> tiledOperands = makeTiledShapes(
        b, loc, linalgOp, operandsToTile, tileOffsets, tileSizes, {}, true);

    // The result types of the clone are the types of the tiled output tensor
    // operands.
    SmallVector<Type> tiledResultTypes;
    for (OpOperand *outTensor : linalgOp.getOutputTensorOperands())
      tiledResultTypes.push_back(
          tiledOperands[outTensor->getOperandNumber()].getType());

    Operation *tiledOp =
        linalgOp.clone(b, loc, tiledResultTypes, tiledOperands);
    offsetIndices(b, cast<LinalgOp>(tiledOp), tileOffsets);

    SmallVector<Operation *> tiledOps;
    tiledOps.push_back(tiledOp);
    return tiledOps;
  }

  /// Computes the offsets and sizes of the tile of result `resultNumber`
  /// produced by the tiled implementation for the iteration-space tile
  /// (`offsets`, `sizes`). The output operand is tiled and the answer is read
  /// back from the resulting `tensor.extract_slice`; fails if tiling did not
  /// yield such an op.
  LogicalResult
  getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                        ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes,
                        SmallVector<OpFoldResult> &resultOffsets,
                        SmallVector<OpFoldResult> &resultSizes) const {
    auto linalgOp = cast<LinalgOp>(op);
    Location loc = op->getLoc();

    // Creates (or folds) an affine.apply of `expr` on `operands`, after fully
    // composing and canonicalizing the map with its operands.
    auto applyExpr = [](OpBuilder &builder, Location applyLoc, AffineExpr expr,
                        ValueRange operands) -> Value {
      AffineMap exprMap = AffineMap::inferFromExprList({expr}).front();
      SmallVector<Value> composedOperands(operands.begin(), operands.end());
      mlir::fullyComposeAffineMapAndOperands(&exprMap, &composedOperands);
      canonicalizeMapAndOperands(&exprMap, &composedOperands);
      return builder.createOrFold<AffineApplyOp>(applyLoc, exprMap,
                                                 composedOperands);
    };

    AffineExpr dim;
    bindDims(b.getContext(), dim);
    AffineExpr dimMinusOne = dim - 1;

    SmallVector<Value> tileSizes =
        getValueOrCreateConstantIndexOp(b, loc, sizes);
    // `size - 1` for every tile size, passed to makeTiledShape as the
    // sub-shape sizes.
    SmallVector<Value> tileSizesMinusOne;
    for (Value size : tileSizes)
      tileSizesMinusOne.push_back(applyExpr(b, loc, dimMinusOne, size));

    OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
    SmallVector<Value> tileOffsets =
        getValueOrCreateConstantIndexOp(b, loc, offsets);
    Value tiledOutput = makeTiledShape(
        b, loc, outOperand->get(), tileSizes,
        linalgOp.getTiedIndexingMap(outOperand), tileOffsets,
        /*ubs*/ {}, tileSizesMinusOne, true);

    if (auto sliceOp = tiledOutput.getDefiningOp<tensor::ExtractSliceOp>()) {
      resultOffsets = sliceOp.getMixedOffsets();
      resultSizes = sliceOp.getMixedSizes();
      return success();
    }
    return failure();
  }

  /// Generates the tiled implementation that computes the tile of result
  /// `resultNumber` described by `offsets` / `sizes` (given in the result's
  /// index space) and returns the corresponding result of the tiled op.
  /// Emits an op error if the result's indexing map is not a projected
  /// permutation or if tiling does not produce exactly one op.
  FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
                                           unsigned resultNumber,
                                           ValueRange dest,
                                           ArrayRef<OpFoldResult> offsets,
                                           ArrayRef<OpFoldResult> sizes,
                                           bool tileDestOperands) const {
    auto linalgOp = cast<LinalgOp>(op);

    // Only projected permutations are handled. This could be relaxed with a
    // more general approach that can map the offsets and sizes from the
    // result to iteration space tiles (filling in full extent for dimensions
    // not used to access the result).
    AffineMap resultMap =
        linalgOp.getTiedIndexingMapForResult(op->getResult(resultNumber));
    if (!resultMap.isProjectedPermutation()) {
      return op->emitOpError(
          "unhandled tiled implementation generation when result is not "
          "accessed using a permuted projection");
    }

    auto numLoops = linalgOp.getNumLoops();
    auto tilingInterfaceOp = cast<TilingInterface>(op);
    SmallVector<OpFoldResult> loopOffsets(numLoops);
    SmallVector<OpFoldResult> loopSizes(numLoops);

    // When the map does not cover every loop, start from the full iteration
    // domain so that loops not used to access the result keep their full
    // extent.
    if (!resultMap.isPermutation()) {
      SmallVector<Range> domain = tilingInterfaceOp.getIterationDomain(b);
      for (unsigned loop = 0, e = domain.size(); loop != e; ++loop) {
        loopOffsets[loop] = domain[loop].offset;
        loopSizes[loop] = domain[loop].size;
      }
    }

    // Scatter the result-space tile onto the loops that index the result.
    for (unsigned resultDim = 0, e = resultMap.getNumResults(); resultDim != e;
         ++resultDim) {
      unsigned loop =
          resultMap.getResult(resultDim).cast<AffineDimExpr>().getPosition();
      loopOffsets[loop] = offsets[resultDim];
      loopSizes[loop] = sizes[resultDim];
    }

    SmallVector<Operation *> tiledOps =
        tilingInterfaceOp.getTiledImplementation(b, dest, loopOffsets,
                                                 loopSizes, tileDestOperands);
    if (tiledOps.size() != 1)
      return op->emitOpError("failed to generate tiled implementation");

    return tiledOps.front()->getResult(resultNumber);
  }
};

} // namespace

/// Attaches the `LinalgOpTilingInterface` external model instantiated for
/// `OpType` to that operation in the context `ctx`.
template <typename OpType>
static void registerOne(MLIRContext *ctx) {
  using TilingModel = LinalgOpTilingInterface<OpType>;
  OpType::template attachInterface<TilingModel>(*ctx);
}

/// Variadic helper function: calls `registerOne` for every type in `OpTypes`,
/// in pack order.
template <typename... OpTypes>
static void registerAll(MLIRContext *ctx) {
  // FIXME: In c++17 this can be simplified by using 'fold expressions'.
  // Expanding the pack inside a braced initializer evaluates the calls left
  // to right; the leading 0 keeps the array non-empty for an empty pack.
  const int expansion[] = {0, (registerOne<OpTypes>(ctx), 0)...};
  (void)expansion;
}

// Defined ahead of the include of LinalgStructuredOps.cpp.inc below, which is
// used as the template argument list of `registerAll` (presumably this macro
// selects the op-list section of the generated file; confirm against the
// generated .inc).
#define GET_OP_LIST

/// Adds an extension to `registry` that attaches the TilingInterface external
/// model to `linalg::GenericOp` and to every op listed in
/// LinalgStructuredOps.cpp.inc.
void mlir::linalg::registerTilingInterfaceExternalModels(
    DialectRegistry &registry) {
  // The capture-less lambda is converted to a plain function pointer before
  // being handed to `addExtension`.
  void (*attachModels)(MLIRContext *, linalg::LinalgDialect *) =
      [](MLIRContext *ctx, linalg::LinalgDialect *dialect) {
        registerOne<linalg::GenericOp>(ctx);
        registerAll<
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
            >(ctx);
      };
  registry.addExtension(attachModels);
}
