//===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"

using namespace mlir;

// For >1-D vector types, gathers the information needed to iterate over all
// 1-D subvectors in the underlying LLVM representation of the n-D vector.
//
// The converted type is peeled one LLVM array level at a time, recording the
// number of elements of each level in `arraySizes`, until a non-array type is
// reached.
//
// On return:
//   - `llvmNDVectorTy` is the converted type, or null when the conversion
//     failed or produced a type that is not LLVM-compatible.
//   - `llvm1DVectorTy` is the innermost non-array type when that type is an
//     LLVM-compatible vector type; otherwise it is left untouched.
LLVM::detail::NDVectorTypeInfo
LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType,
                                      LLVMTypeConverter &converter) {
  assert(vectorType.getRank() > 1 && "expected >1D vector type");

  NDVectorTypeInfo info;
  info.llvmNDVectorTy = converter.convertType(vectorType);
  const bool converted =
      info.llvmNDVectorTy && LLVM::isCompatibleType(info.llvmNDVectorTy);
  if (!converted) {
    // Signal the failure to the caller through a null n-D type.
    info.llvmNDVectorTy = nullptr;
    return info;
  }

  // An n-D vector is expected to nest (rank - 1) array levels around the
  // innermost 1-D vector.
  info.arraySizes.reserve(vectorType.getRank() - 1);

  // Strip the array wrappers outermost-first, remembering each extent.
  auto innerTy = info.llvmNDVectorTy;
  while (auto arrayTy = innerTy.dyn_cast<LLVM::LLVMArrayType>()) {
    info.arraySizes.push_back(arrayTy.getNumElements());
    innerTy = arrayTy.getElementType();
  }

  // Only publish the innermost type when it really is a vector.
  if (LLVM::isCompatibleVectorType(innerTy))
    info.llvm1DVectorTy = innerTy;
  return info;
}

// Express `linearIndex` in terms of coordinates of `basis`.
// Returns the empty vector when linearIndex is out of the range [0, P) where
// P is the product of all the basis coordinates. In particular, a zero basis
// element makes P zero, so every linearIndex is out of range and the empty
// vector is returned.
//
// Note that for an empty `basis` the result for linearIndex 0 is also the
// empty vector (there are no coordinates to report).
//
// Prerequisites:
//   Basis is an array of nonnegative integers (signed type inherited from
//   vector shape type). A negative element is treated as invalid and yields
//   the empty vector.
SmallVector<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef<int64_t> basis,
                                                     unsigned linearIndex) {
  SmallVector<int64_t, 4> res;
  res.reserve(basis.size());
  // Do the arithmetic in 64 bits so that basis elements are never truncated
  // to `unsigned` (a truncated element could even become zero).
  int64_t remaining = linearIndex;
  // Peel off coordinates starting from the last (fastest varying) basis
  // element.
  for (int64_t basisElement : llvm::reverse(basis)) {
    // Guard the modulo/division below against a zero (or invalid negative)
    // basis element.
    if (basisElement <= 0)
      return {};
    res.push_back(remaining % basisElement);
    remaining /= basisElement;
  }
  // A non-zero leftover means linearIndex >= P.
  if (remaining > 0)
    return {};
  // Coordinates were produced last-to-first; restore the order of `basis`.
  std::reverse(res.begin(), res.end());
  return res;
}

// Iterates over the linear indices of the array part of the n-D vector
// described by `info`, converts each index to coordinates with
// `getCoordinates`, and invokes `fun` with those coordinates wrapped in an
// i64 array attribute created through `builder`.
//
// The number of visited positions is the product of `info.arraySizes`.
// Iteration stops early as soon as `getCoordinates` returns an empty vector.
void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
                                   OpBuilder &builder,
                                   function_ref<void(ArrayAttr)> fun) {
  // Upper bound of the linear index space.
  unsigned numPositions = 1;
  for (auto dimSize : info.arraySizes)
    numPositions *= dimSize;

  unsigned idx = 0;
  while (idx < numPositions) {
    auto indices = getCoordinates(info.arraySizes, idx);
    // An empty result means the linear index is out of bounds: we are done.
    if (indices.empty())
      return;
    assert(indices.size() == info.arraySizes.size());
    fun(builder.getI64ArrayAttr(indices));
    ++idx;
  }
}

// Rewrites `op`, whose first result is an n-D vector (n > 1), by unrolling it
// over the 1-D subvectors of the converted LLVM array-of-vectors type.
//
// For every position of the result's array part, the 1-D slice of each value
// in `operands` is extracted, `createOperand` is called with the 1-D result
// type and the extracted slices, and the value it returns is inserted into
// the aggregate that finally replaces `op`.
//
// Returns failure, without creating any operation, when the type of the
// result or of any operand of `op` cannot be decomposed into an LLVM n-D
// vector type with a 1-D vector at its core. Returns success otherwise.
LogicalResult LLVM::detail::handleMultidimensionalVectors(
    Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
    std::function<Value(Type, ValueRange)> createOperand,
    ConversionPatternRewriter &rewriter) {
  auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();

  // Collect the 1-D vector type of every original operand; these are the
  // result types of the extractions performed below.
  SmallVector<Type> operand1DVectorTypes;
  for (Value operand : op->getOperands()) {
    auto operandNDVectorType = operand.getType().cast<VectorType>();
    auto operandTypeInfo =
        extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
    // extractNDVectorTypeInfo reports problems through null types; bail out
    // before any IR is created rather than building ops with a null type.
    if (!operandTypeInfo.llvmNDVectorTy || !operandTypeInfo.llvm1DVectorTy)
      return failure();
    operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
  }
  auto resultTypeInfo =
      extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
  if (!resultTypeInfo.llvmNDVectorTy || !resultTypeInfo.llvm1DVectorTy)
    return failure();
  auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
  auto resultNDVectorTy = resultTypeInfo.llvmNDVectorTy;
  auto loc = op->getLoc();
  // Start from an undefined aggregate and fill it one 1-D vector at a time.
  Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectorTy);
  nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) {
    // For this unrolled `position` corresponding to the `linearIndex`^th
    // element, extract operand vectors
    SmallVector<Value, 4> extractedOperands;
    for (auto operand : llvm::enumerate(operands)) {
      extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
          loc, operand1DVectorTypes[operand.index()], operand.value(),
          position));
    }
    // Let the caller build the 1-D computation, then store its result at the
    // same position of the result aggregate.
    Value newVal = createOperand(result1DVectorTy, extractedOperands);
    desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectorTy, desc,
                                                newVal, position);
  });
  rewriter.replaceOp(op, desc);
  return success();
}

// Rewrites `op` into the operation named `targetOp`, taking n-D vectors into
// account.
//
// Fails when any of `operands` does not have an LLVM-compatible type. When
// the first operand is not an LLVM array, the rewrite is delegated to
// `oneToOneRewrite`. Otherwise the operation is unrolled through
// `handleMultidimensionalVectors`, creating one `targetOp` per 1-D vector
// that carries over the attributes of `op`.
LogicalResult LLVM::detail::vectorOneToOneRewrite(
    Operation *op, StringRef targetOp, ValueRange operands,
    LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
  assert(!operands.empty());

  // Cannot convert ops if their operands are not of LLVM type.
  auto hasLLVMType = [](Type t) { return isCompatibleType(t); };
  if (!llvm::all_of(operands.getTypes(), hasLLVMType))
    return failure();

  auto firstOperandTy = operands[0].getType();
  if (firstOperandTy.isa<LLVM::LLVMArrayType>()) {
    // Builds a single `targetOp` on 1-D vectors and hands back its first
    // result.
    auto buildTargetOp = [op, targetOp, &rewriter](Type llvm1DVectorTy,
                                                   ValueRange slicedOperands) {
      OperationState state(op->getLoc(), targetOp);
      state.addTypes(llvm1DVectorTy);
      state.addOperands(slicedOperands);
      state.addAttributes(op->getAttrs());
      Operation *created = rewriter.createOperation(state);
      return created->getResult(0);
    };
    return handleMultidimensionalVectors(op, operands, typeConverter,
                                         buildTargetOp, rewriter);
  }

  // Not an array of vectors: a plain one-to-one rewrite is enough.
  return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
}
