//===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/X86Vector/Transforms.h"

#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::x86vector;

/// Returns the element type of the vector whose width decides which intrinsic
/// an X86Vector op is lowered to. By default this is the `src` operand.
template <typename OpTy>
static Type getSrcVectorElementType(OpTy op) {
  auto srcVectorType = op.src().getType().template cast<VectorType>();
  return srcVectorType.getElementType();
}
/// Vp2IntersectOp is keyed on its `a` operand rather than on `src`.
template <>
Type getSrcVectorElementType<Vp2IntersectOp>(Vp2IntersectOp op) {
  auto lhsVectorType = op.a().getType().template cast<VectorType>();
  return lhsVectorType.getElementType();
}

namespace {

/// Base conversion for AVX512 ops that can be lowered to one of the two
/// intrinsics based on the bitwidth of their "main" vector element type. This
/// relies on the to-LLVM-dialect conversion helpers to correctly pack the
/// results of multi-result intrinsic ops.
///
/// The "main" element type comes from getSrcVectorElementType. That is the
/// element type of `src` for most ops and of `a` for Vp2IntersectOp.
/// - A 32-bit element type selects `Intr32OpTy`.
/// - A 64-bit element type selects `Intr64OpTy`.
/// - Any other type, including non-integer/non-float types, makes the pattern
///   fail to match instead of asserting.
template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
  explicit LowerToIntrinsic(LLVMTypeConverter &converter)
      : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}

  /// Returns the type converter this pattern was constructed with, as the
  /// LLVMTypeConverter the constructor received.
  LLVMTypeConverter &getTypeConverter() const {
    return *static_cast<LLVMTypeConverter *>(
        OpConversionPattern<OpTy>::getTypeConverter());
  }

  LogicalResult
  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type elementType = getSrcVectorElementType<OpTy>(op);
    // getIntOrFloatBitWidth() asserts on types that are neither integer nor
    // float (e.g. index), so map those to an unsupported width (0) and let the
    // match fail gracefully below.
    unsigned bitwidth =
        elementType.isIntOrFloat() ? elementType.getIntOrFloatBitWidth() : 0;
    if (bitwidth == 32)
      return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
                                           adaptor.getOperands(),
                                           getTypeConverter(), rewriter);
    if (bitwidth == 64)
      return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
                                           adaptor.getOperands(),
                                           getTypeConverter(), rewriter);
    // Keep the diagnostic generic. The inspected operand is not always `src`
    // (Vp2IntersectOp uses `a`), and the elements may be integers rather than
    // f32/f64.
    return rewriter.notifyMatchFailure(
        op, "expected the main vector element type to be 32- or 64-bit");
  }
};

/// Rewrites MaskCompressOp into MaskCompressIntrOp. The intrinsic always takes
/// an explicit `src` vector, chosen in this order:
/// 1. the op's `src` operand, when present;
/// 2. otherwise an arith constant built from the `constant_src` attribute;
/// 3. otherwise an arith zero constant of the operand type.
struct MaskCompressOpConversion
    : public ConvertOpToLLVMPattern<MaskCompressOp> {
  using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto vecType = adaptor.a().getType();
    auto loc = op.getLoc();

    Value srcVec;
    if (op.src())
      srcVec = adaptor.src();
    else if (op.constant_src())
      srcVec = rewriter.create<arith::ConstantOp>(loc, vecType,
                                                  op.constant_srcAttr());
    else
      srcVec = rewriter.create<arith::ConstantOp>(
          loc, vecType, rewriter.getZeroAttr(vecType));

    rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, vecType, adaptor.a(),
                                                    srcVec, adaptor.k());
    return success();
  }
};

/// Rewrites RsqrtOp into RsqrtIntrOp. The result type is taken from the
/// (already converted) operand `a`.
struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
  using ConvertOpToLLVMPattern<RsqrtOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto input = adaptor.a();
    auto resultType = input.getType();
    rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, resultType, input);
    return success();
  }
};

/// Rewrites DotOp into DotIntrOp. An extra i8 constant operand (0xff) is
/// materialized as an LLVM constant and passed as the intrinsic's last operand.
struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
  using ConvertOpToLLVMPattern<DotOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(DotOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto lhs = adaptor.a();
    auto rhs = adaptor.b();
    auto i8Type = rewriter.getI8Type();
    // 0xff: dot product over all elements, broadcast into all elements.
    auto allLanesMask = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
    Value scale =
        rewriter.create<LLVM::ConstantOp>(op.getLoc(), i8Type, allLanesMask);
    rewriter.replaceOpWithNewOp<DotIntrOp>(op, lhs.getType(), lhs, rhs, scale);
    return success();
  }
};

/// Compile-time record pairing a "main" AVX512 op with the intrinsic ops it
/// lowers to for 32-bit and 64-bit vector elements. It carries no data and
/// only re-exports the three op types as member aliases.
template <typename Main, typename Narrow, typename Wide>
struct RegEntry {
  using MainOp = Main;
  using Intr32Op = Narrow;
  using Intr64Op = Wide;
};

/// A container for op association entries facilitating the configuration of
/// dialect conversion.
template <typename... Args>
struct RegistryImpl {
  /// Registers the patterns specializing the "main" op to one of the
  /// "intrinsic" ops depending on elemental type.
  static void registerPatterns(LLVMTypeConverter &converter,
                               RewritePatternSet &patterns) {
    patterns
        .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
                              typename Args::Intr64Op>...>(converter);
  }

  /// Configures the conversion target to lower out "main" ops.
  static void configureTarget(LLVMConversionTarget &target) {
    target.addIllegalOp<typename Args::MainOp...>();
    target.addLegalOp<typename Args::Intr32Op...>();
    target.addLegalOp<typename Args::Intr64Op...>();
  }
};

using Registry = RegistryImpl<
    RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
    RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
    RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;

} // namespace

/// Appends to `patterns` every rewrite needed to turn X86Vector ops into their
/// intrinsic counterparts. These are the element-width-dispatching registry
/// patterns, followed by the dedicated MaskCompress, Rsqrt and Dot
/// conversions.
void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
  Registry::registerPatterns(converter, patterns);
  patterns.add<MaskCompressOpConversion>(converter);
  patterns.add<RsqrtOpConversion>(converter);
  patterns.add<DotOpConversion>(converter);
}

/// Configures `target` so that the source X86Vector ops must be converted,
/// while their intrinsic replacements are accepted as legal.
void mlir::configureX86VectorLegalizeForExportTarget(
    LLVMConversionTarget &target) {
  Registry::configureTarget(target);
  // Ops with a hand-written conversion pattern: intrinsic form is legal, the
  // source form is not.
  target.addLegalOp<MaskCompressIntrOp, RsqrtIntrOp, DotIntrOp>();
  target.addIllegalOp<MaskCompressOp, RsqrtOp, DotOp>();
}
