//===- LegalizeForLLVMExport.cpp - Prepare ArmSVE for LLVM translation ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::arm_sve;

// Check that `type` is usable as an LLVM dialect type and hand it back.
//
// A null type is passed through as null. For a non-null type that is not
// LLVM-compatible, an error diagnostic is emitted at an unknown location; the
// type is nevertheless returned unchanged, so callers only observe a null
// result when the input itself was null.
static Type unwrap(Type type) {
  if (!type)
    return nullptr;
  if (LLVM::isCompatibleType(type))
    return type;
  MLIRContext *context = type.getContext();
  emitError(UnknownLoc::get(context), "conversion resulted in a non-LLVM type");
  return type;
}

// Convert an ArmSVE scalable vector type to the LLVM dialect scalable vector
// type. The element type is converted through `converter`, and the last
// dimension of the source shape is used as the element count passed to
// LLVMScalableVectorType::get. Returns an empty Optional when the element type
// conversion yields a null type.
static Optional<Type>
convertScalableVectorTypeToLLVM(ScalableVectorType svType,
                                LLVMTypeConverter &converter) {
  Type llvmElementType = unwrap(converter.convertType(svType.getElementType()));
  if (!llvmElementType)
    return {};

  const auto minNumElements = svType.getShape().back();
  return LLVM::LLVMScalableVectorType::get(llvmElementType, minNumElements);
}

/// Conversion pattern that updates `OpTy` in place so that it consumes the
/// operands supplied by the conversion framework. The pattern reports a match
/// failure when the supplied operand types are identical to the types the op
/// already uses, since there is nothing to forward in that case.
template <typename OpTy>
class ForwardOperands : public OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpTy op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    const bool typesAlreadyMatch =
        ValueRange(operands).getTypes() == op->getOperands().getTypes();
    if (typesAlreadyMatch)
      return rewriter.notifyMatchFailure(op, "operand types already match");

    // Swap in the new operands as a single in-place root update.
    auto adoptOperands = [&]() { op->setOperands(operands); };
    rewriter.updateRootInPlace(op, adoptOperands);
    return success();
  }
};

/// Conversion pattern that unconditionally replaces the operands of a
/// `ReturnOp` with the operands supplied by the conversion framework, as an
/// in-place root update. Unlike ForwardOperands it performs no check on
/// whether the operand types changed. It is not registered by
/// populateArmSVELegalizeForLLVMExportPatterns, which registers
/// ForwardOperands<ReturnOp> instead.
class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern<ReturnOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final {
    auto adoptOperands = [&]() { op->setOperands(operands); };
    rewriter.updateRootInPlace(op, adoptOperands);
    return success();
  }
};

// Materialization callback (registered as a source materialization below)
// that produces a value of the ArmSVE type `svType` from a single LLVM
// scalable vector input by inserting an UnrealizedConversionCastOp at `loc`.
// For any other input shape (not exactly one value, or a value that is not an
// LLVM scalable vector) a null Value is returned.
static Optional<Value> addUnrealizedCast(OpBuilder &builder,
                                         ScalableVectorType svType,
                                         ValueRange inputs, Location loc) {
  const bool isSingleLLVMScalableVector =
      inputs.size() == 1 &&
      inputs[0].getType().isa<LLVM::LLVMScalableVectorType>();
  if (!isSingleLLVMScalableVector)
    return Value();

  auto castOp = builder.create<UnrealizedConversionCastOp>(loc, svType, inputs);
  return castOp.getResult(0);
}

// Lowering pattern aliases. Each alias instantiates
// OneToOneConvertToLLVMPattern with an ArmSVE dialect op as the source and the
// corresponding `*IntrOp` as the target. All of them are registered in
// populateArmSVELegalizeForLLVMExportPatterns.

// Dot-product, matrix multiply-accumulate and vector-scale ops.
using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using VectorScaleOpLowering =
    OneToOneConvertToLLVMPattern<VectorScaleOp, VectorScaleIntrOp>;
// Masked arithmetic ops (integer `I` and floating-point `F` variants, with
// separate signed/unsigned integer division).
using ScalableMaskedAddIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
                                 ScalableMaskedAddIIntrOp>;
using ScalableMaskedAddFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedAddFOp,
                                 ScalableMaskedAddFIntrOp>;
using ScalableMaskedSubIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubIOp,
                                 ScalableMaskedSubIIntrOp>;
using ScalableMaskedSubFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSubFOp,
                                 ScalableMaskedSubFIntrOp>;
using ScalableMaskedMulIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulIOp,
                                 ScalableMaskedMulIIntrOp>;
using ScalableMaskedMulFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedMulFOp,
                                 ScalableMaskedMulFIntrOp>;
using ScalableMaskedSDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedSDivIOp,
                                 ScalableMaskedSDivIIntrOp>;
using ScalableMaskedUDivIOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedUDivIOp,
                                 ScalableMaskedUDivIIntrOp>;
using ScalableMaskedDivFOpLowering =
    OneToOneConvertToLLVMPattern<ScalableMaskedDivFOp,
                                 ScalableMaskedDivFIntrOp>;

// Load operation is lowered to code that obtains a pointer to the indexed
// element and loads from it.
//
// The element pointer computed from the memref is bitcast to a pointer to the
// loaded vector type before the load. A fixed-size VectorType result is used
// as the pointee type directly; a ScalableVectorType result is first converted
// to the LLVM scalable vector type using the pattern's own type converter.
// The pattern fails to match (without creating any IR) when the memref is not
// convertible with identity maps, when the result type is neither of the two
// vector kinds, or when the scalable vector type cannot be converted.
struct ScalableLoadOpLowering : public ConvertOpToLLVMPattern<ScalableLoadOp> {
  using ConvertOpToLLVMPattern<ScalableLoadOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ScalableLoadOp loadOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = loadOp.getMemRefType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    ScalableLoadOp::Adaptor transformed(operands);

    // Resolve the pointee type up front so that every bail-out happens before
    // any new operation is created.
    Type resultType = loadOp.result().getType();
    Type llvmDataType;
    if (auto vectorType = resultType.dyn_cast<VectorType>()) {
      llvmDataType = vectorType;
    } else if (auto svType = resultType.dyn_cast<ScalableVectorType>()) {
      // Reuse the converter this pattern was created with rather than
      // building a fresh LLVMTypeConverter on every match.
      Optional<Type> converted =
          convertScalableVectorTypeToLLVM(svType, *getTypeConverter());
      if (!converted)
        return rewriter.notifyMatchFailure(
            loadOp, "failed to convert scalable vector result type");
      llvmDataType = *converted;
    } else {
      return rewriter.notifyMatchFailure(
          loadOp, "expected a vector or scalable vector result type");
    }
    auto llvmDataTypePtr = LLVM::LLVMPointerType::get(llvmDataType);

    Value dataPtr =
        getStridedElementPtr(loadOp.getLoc(), type, transformed.base(),
                             transformed.index(), rewriter);
    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
        loadOp.getLoc(), llvmDataTypePtr, dataPtr);
    rewriter.replaceOpWithNewOp<LLVM::LoadOp>(loadOp, bitCastedPtr);
    return success();
  }
};

// Store operation is lowered to code that obtains a pointer to the indexed
// element, and stores the given value to it.
//
// The element pointer computed from the memref is bitcast to a pointer to the
// stored vector type before the store. A fixed-size VectorType value is used
// as the pointee type directly; a ScalableVectorType value is first converted
// to the LLVM scalable vector type using the pattern's own type converter.
// The pattern fails to match (without creating any IR) when the memref is not
// convertible with identity maps, when the stored value is neither of the two
// vector kinds, or when the scalable vector type cannot be converted.
struct ScalableStoreOpLowering
    : public ConvertOpToLLVMPattern<ScalableStoreOp> {
  using ConvertOpToLLVMPattern<ScalableStoreOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(ScalableStoreOp storeOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto type = storeOp.getMemRefType();
    if (!isConvertibleAndHasIdentityMaps(type))
      return failure();

    ScalableStoreOp::Adaptor transformed(operands);

    // Resolve the pointee type up front so that every bail-out happens before
    // any new operation is created.
    Type valueType = storeOp.value().getType();
    Type llvmDataType;
    if (auto vectorType = valueType.dyn_cast<VectorType>()) {
      llvmDataType = vectorType;
    } else if (auto svType = valueType.dyn_cast<ScalableVectorType>()) {
      // Reuse the converter this pattern was created with rather than
      // building a fresh LLVMTypeConverter on every match.
      Optional<Type> converted =
          convertScalableVectorTypeToLLVM(svType, *getTypeConverter());
      if (!converted)
        return rewriter.notifyMatchFailure(
            storeOp, "failed to convert scalable vector value type");
      llvmDataType = *converted;
    } else {
      return rewriter.notifyMatchFailure(
          storeOp, "expected a vector or scalable vector value type");
    }
    auto llvmDataTypePtr = LLVM::LLVMPointerType::get(llvmDataType);

    Value dataPtr =
        getStridedElementPtr(storeOp.getLoc(), type, transformed.base(),
                             transformed.index(), rewriter);
    Value bitCastedPtr = rewriter.create<LLVM::BitcastOp>(
        storeOp.getLoc(), llvmDataTypePtr, dataPtr);
    rewriter.replaceOpWithNewOp<LLVM::StoreOp>(storeOp, transformed.value(),
                                               bitCastedPtr);
    return success();
  }
};

// Register one-to-one lowerings from the unmasked ArmSVE arithmetic ops to the
// matching LLVM dialect arithmetic ops.
static void
populateBasicSVEArithmeticExportPatterns(LLVMTypeConverter &converter,
                                         OwningRewritePatternList &patterns) {
  // Addition and subtraction.
  patterns.add<OneToOneConvertToLLVMPattern<ScalableAddIOp, LLVM::AddOp>,
               OneToOneConvertToLLVMPattern<ScalableAddFOp, LLVM::FAddOp>,
               OneToOneConvertToLLVMPattern<ScalableSubIOp, LLVM::SubOp>,
               OneToOneConvertToLLVMPattern<ScalableSubFOp, LLVM::FSubOp>>(
      converter);
  // Multiplication.
  patterns.add<OneToOneConvertToLLVMPattern<ScalableMulIOp, LLVM::MulOp>,
               OneToOneConvertToLLVMPattern<ScalableMulFOp, LLVM::FMulOp>>(
      converter);
  // Division (signed integer, unsigned integer, floating point).
  patterns.add<OneToOneConvertToLLVMPattern<ScalableSDivIOp, LLVM::SDivOp>,
               OneToOneConvertToLLVMPattern<ScalableUDivIOp, LLVM::UDivOp>,
               OneToOneConvertToLLVMPattern<ScalableDivFOp, LLVM::FDivOp>>(
      converter);
}

// Mark every unmasked ArmSVE arithmetic op as illegal on `target`.
static void
configureBasicSVEArithmeticLegalizations(LLVMConversionTarget &target) {
  // Addition and subtraction.
  target.addIllegalOp<ScalableAddIOp, ScalableAddFOp, ScalableSubIOp,
                      ScalableSubFOp>();
  // Multiplication.
  target.addIllegalOp<ScalableMulIOp, ScalableMulFOp>();
  // Division (signed integer, unsigned integer, floating point).
  target.addIllegalOp<ScalableSDivIOp, ScalableUDivIOp, ScalableDivFOp>();
}

// Register one-to-one lowerings from the ArmSVE comparison ops to the LLVM
// dialect floating-point and integer comparison ops.
static void
populateSVEMaskGenerationExportPatterns(LLVMTypeConverter &converter,
                                        OwningRewritePatternList &patterns) {
  patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpFOp, LLVM::FCmpOp>>(
      converter);
  patterns.add<OneToOneConvertToLLVMPattern<ScalableCmpIOp, LLVM::ICmpOp>>(
      converter);
}

// Mark the ArmSVE comparison ops as illegal on `target`.
static void
configureSVEMaskGenerationLegalizations(LLVMConversionTarget &target) {
  target.addIllegalOp<ScalableCmpFOp>();
  target.addIllegalOp<ScalableCmpIOp>();
}

/// Populate `patterns` with the rewrites that convert ArmSVE ops to the LLVM
/// dialect, and extend `converter` so that it understands ArmSVE scalable
/// vector types.
void mlir::populateArmSVELegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
  // Remove any ArmSVE-specific types from function signatures and results.
  populateFuncOpTypeConversionPattern(patterns, converter);

  // Teach the converter how to translate scalable vector types, and how to
  // get back to them through an unrealized cast.
  auto convertScalableVector = [&converter](ScalableVectorType svType) {
    return convertScalableVectorTypeToLLVM(svType, converter);
  };
  converter.addConversion(convertScalableVector);
  converter.addSourceMaterialization(addUnrealizedCast);

  // Calls and returns only need their operands forwarded.
  MLIRContext *context = &converter.getContext();
  patterns.add<ForwardOperands<CallOp>, ForwardOperands<CallIndirectOp>,
               ForwardOperands<ReturnOp>>(converter, context);

  // Ops that map one-to-one onto intrinsic ops.
  patterns.add<SdotOpLowering, SmmlaOpLowering, UdotOpLowering,
               UmmlaOpLowering, VectorScaleOpLowering>(converter);
  patterns.add<ScalableMaskedAddIOpLowering, ScalableMaskedAddFOpLowering,
               ScalableMaskedSubIOpLowering, ScalableMaskedSubFOpLowering,
               ScalableMaskedMulIOpLowering, ScalableMaskedMulFOpLowering,
               ScalableMaskedSDivIOpLowering, ScalableMaskedUDivIOpLowering,
               ScalableMaskedDivFOpLowering>(converter);

  // Memory access ops.
  patterns.add<ScalableLoadOpLowering, ScalableStoreOpLowering>(converter);

  // Unmasked arithmetic and comparison ops.
  populateBasicSVEArithmeticExportPatterns(converter, patterns);
  populateSVEMaskGenerationExportPatterns(converter, patterns);
}

/// Configure `target` for the ArmSVE-to-LLVM legalization: the intrinsic ops
/// are legal, the ArmSVE ops they replace are illegal, and functions, calls
/// and returns are legal only when none of their types is an ArmSVE scalable
/// vector type.
void mlir::configureArmSVELegalizeForExportTarget(
    LLVMConversionTarget &target) {
  // Intrinsic ops.
  target.addLegalOp<SdotIntrOp, SmmlaIntrOp, UdotIntrOp, UmmlaIntrOp,
                    VectorScaleIntrOp>();
  target.addLegalOp<ScalableMaskedAddIIntrOp, ScalableMaskedAddFIntrOp,
                    ScalableMaskedSubIIntrOp, ScalableMaskedSubFIntrOp,
                    ScalableMaskedMulIIntrOp, ScalableMaskedMulFIntrOp,
                    ScalableMaskedSDivIIntrOp, ScalableMaskedUDivIIntrOp,
                    ScalableMaskedDivFIntrOp>();

  // ArmSVE ops that must be rewritten.
  target.addIllegalOp<SdotOp, SmmlaOp, UdotOp, UmmlaOp, VectorScaleOp>();
  target.addIllegalOp<ScalableMaskedAddIOp, ScalableMaskedAddFOp,
                      ScalableMaskedSubIOp, ScalableMaskedSubFOp,
                      ScalableMaskedMulIOp, ScalableMaskedMulFOp,
                      ScalableMaskedSDivIOp, ScalableMaskedUDivIOp,
                      ScalableMaskedDivFOp>();
  target.addIllegalOp<ScalableLoadOp, ScalableStoreOp>();

  // True when at least one type in the range is an ArmSVE scalable vector.
  auto hasScalableVectorType = [](TypeRange types) {
    return llvm::any_of(types, [](Type type) {
      return type.isa<arm_sve::ScalableVectorType>();
    });
  };

  // Functions are legal once their signature is free of scalable vectors.
  target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
    if (hasScalableVectorType(op.getType().getInputs()))
      return false;
    return !hasScalableVectorType(op.getType().getResults());
  });

  // Calls and returns are legal once neither their operands nor their results
  // carry scalable vectors.
  target.addDynamicallyLegalOp<CallOp, CallIndirectOp, ReturnOp>(
      [hasScalableVectorType](Operation *op) {
        if (hasScalableVectorType(op->getOperandTypes()))
          return false;
        return !hasScalableVectorType(op->getResultTypes());
      });

  configureBasicSVEArithmeticLegalizations(target);
  configureSVEMaskGenerationLegalizations(target);
}
