//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements passes to convert `gpu.launch_func` op into a sequence
// of LLVM calls that emulate the host and device sides.
//
//===----------------------------------------------------------------------===//

#include "../PassDetail.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

/// Prefix that GPU-to-SPIR-V conversion prepends to a GPU kernel module's name
/// to form the symbol name of the corresponding SPIR-V module, i.e.
///   __spv__{kernel_module_name}
static constexpr char kSPIRVModule[] = "__spv__";

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Returns the string name of the `DescriptorSet` decoration.
static std::string descriptorSetName() {
  return llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::DescriptorSet));
}

/// Returns the string name of the `Binding` decoration.
static std::string bindingName() {
  return llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::Binding));
}

/// Maps a SPIR-V global variable carrying a binding attribute back to the
/// index of the kernel operand it represents. Kernel operand `i` is assumed to
/// have been assigned (descriptorSet, binding) = (0, i) by
/// `LowerABIAttributesPass`, so the index is simply the binding number and the
/// descriptor set is not consulted. The binding attribute must be present on
/// `op`; it is read without a null check.
static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
  auto bindingAttr = op->getAttrOfType<IntegerAttr>(bindingName());
  return static_cast<unsigned>(bindingAttr.getInt());
}

/// Emits an LLVM dialect memcpy (`LLVM::MemcpyOp`) at the builder's current
/// insertion point that copies `size` bytes from `src` to `dst`. The copy is
/// always marked non-volatile.
static void copy(Location loc, Value dst, Value src, Value size,
                 OpBuilder &builder) {
  // The memcpy op takes its volatility flag as an explicit i1 operand.
  Type flagType = IntegerType::get(builder.getContext(), 1);
  Value notVolatile = builder.create<LLVM::ConstantOp>(
      loc, flagType, builder.getBoolAttr(false));
  builder.create<LLVM::MemcpyOp>(loc, dst, src, size, notVolatile);
}

/// Builds a symbol name for the LLVM global that mirrors a SPIR-V global
/// variable. The kernel module name, the variable's symbol name, and its
/// descriptor set and binding numbers are combined to avoid symbol
/// conflicts:
///   {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
/// Both decoration attributes must be present on `op` (they are read without
/// a null check; see `hasDescriptorSetAndBinding`).
static std::string
createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
                                 StringRef kernelModuleName) {
  auto readIntAttr = [&](const std::string &attrName) {
    return op->getAttrOfType<IntegerAttr>(attrName).getInt();
  };
  std::string descriptorSetStr =
      std::to_string(readIntAttr(descriptorSetName()));
  std::string bindingStr = std::to_string(readIntAttr(bindingName()));
  return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
                       kernelModuleName.str(), op.sym_name().str(),
                       descriptorSetStr, bindingStr)
      .str();
}

/// Returns true only if `op` carries integer attributes for both the
/// descriptor set and the binding decorations.
static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
  if (!op->getAttrOfType<IntegerAttr>(descriptorSetName()))
    return false;
  return static_cast<bool>(op->getAttrOfType<IntegerAttr>(bindingName()));
}

/// Collects the `spv.GlobalVariable`s of `module` that carry both a descriptor
/// set and a binding decoration into `globalVariableMap`, keyed by the kernel
/// operand index they represent (see `calculateGlobalIndex`). The module is
/// expected to hold a single entry point function, so every such decorated
/// global is treated as a kernel argument; a later global with the same index
/// overwrites an earlier one. Emits an error and fails if the module does not
/// contain exactly one entry point.
static LogicalResult getKernelGlobalVariables(
    spirv::ModuleOp module,
    DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
  if (!llvm::hasSingleElement(module.getOps<spirv::EntryPointOp>()))
    return module.emitError(
        "The module must contain exactly one entry point function");

  for (spirv::GlobalVariableOp global :
       module.getOps<spirv::GlobalVariableOp>()) {
    if (!hasDescriptorSetAndBinding(global))
      continue;
    globalVariableMap[calculateGlobalIndex(global)] = global;
  }
  return success();
}

/// Encodes the SPIR-V module's symbolic name into the name of the entry point
/// function. The function, together with all of its symbol uses inside
/// `module`, is renamed to:
///   {spv_module_name}_{function_name}
/// Emits an error and returns failure in two cases: the module does not
/// contain exactly one entry point, or the entry point names a function that
/// cannot be found. Also returns failure if the symbol uses cannot be
/// replaced.
static LogicalResult encodeKernelName(spirv::ModuleOp module) {
  StringRef spvModuleName = module.sym_name().getValue();
  // This can run on SPIR-V modules that no `gpu.launch_func` referenced, so
  // `getKernelGlobalVariables()` has not necessarily validated them. Check for
  // a single entry point here rather than dereferencing `begin()` of a range
  // that may be empty.
  auto entryPoints = module.getOps<spirv::EntryPointOp>();
  if (!llvm::hasSingleElement(entryPoints)) {
    return module.emitError(
        "The module must contain exactly one entry point function");
  }
  spirv::EntryPointOp entryPoint = *entryPoints.begin();
  StringRef funcName = entryPoint.fn();
  auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.fnAttr());
  // `replaceAllSymbolUses` must not be handed a null operation.
  if (!funcOp) {
    return entryPoint.emitError("entry point refers to unknown function '")
           << funcName << "'";
  }
  StringAttr newFuncName =
      StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
  if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
    return failure();
  SymbolTable::setSymbolName(funcOp, newFuncName);
  return success();
}

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//

namespace {

/// Structure to group information about the variables being copied.
/// Within `GPULaunchLowering`, `src` is the operand's allocated buffer
/// pointer, `dst` is the address of the LLVM global that emulates the device
/// side, and `size` is the byte count. The same triple is reused to copy the
/// data back after the kernel call.
struct CopyInfo {
  Value dst;
  Value src;
  Value size;
};

/// This pattern emulates a call to the kernel in LLVM dialect. For that, we
/// copy the data to the global variable (emulating device side), call the
/// kernel as a normal void LLVM function, and copy the data back (emulating the
/// host side).
///
/// The pattern fails to match unless all of the following hold:
///   - every kernel operand is a statically shaped memref;
///   - the SPIR-V kernel module contains exactly one entry point;
///   - the SPIR-V kernel module contains a decorated `spv.GlobalVariable`
///     whose binding number equals that operand's position.
class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
  using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;

  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto *op = launchOp.getOperation();
    MLIRContext *context = rewriter.getContext();
    auto module = launchOp->getParentOfType<ModuleOp>();

    // Get the SPIR-V module that represents the gpu kernel module. The module
    // is named:
    //   __spv__{kernel_module_name}
    // based on GPU to SPIR-V conversion.
    StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
    std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
    auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
        StringAttr::get(context, spvModuleName));
    if (!spvModule) {
      return launchOp.emitOpError("SPIR-V kernel module '")
             << spvModuleName << "' is not found";
    }

    // Get all global variables associated with the kernel operands. This is
    // done before any IR is created, so this failure leaves the IR untouched.
    DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
    if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
      return failure();

    // Declare kernel function in the main module so that it later can be linked
    // with its definition from the kernel module. We know that the kernel
    // function would have no arguments and the data is passed via global
    // variables. The name of the kernel will be
    //   {spv_module_name}_{kernel_function_name}
    // to avoid symbolic name conflicts.
    StringRef kernelFuncName = launchOp.getKernelName().getValue();
    std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
    auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
        StringAttr::get(context, newKernelFuncName));
    if (!kernelFunc) {
      // The guard restores the insertion point that was active on entry when
      // it goes out of scope.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
          rewriter.getUnknownLoc(), newKernelFuncName,
          LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
                                      ArrayRef<Type>()));
    }

    // Traverse kernel operands that were converted to MemRefDescriptors. For
    // each operand, create a global variable and copy data from operand to it.
    Location loc = launchOp.getLoc();
    SmallVector<CopyInfo, 4> copyInfo;
    auto numKernelOperands = launchOp.getNumKernelOperands();
    auto kernelOperands = operands.take_back(numKernelOperands);
    for (auto operand : llvm::enumerate(kernelOperands)) {
      // The kernel's operand must be a statically shaped memref.
      // `getMemRefDescriptorSizes` below is given no dynamic sizes, so a
      // dynamically shaped memref cannot be sized here.
      auto memRefType = launchOp.getKernelOperand(operand.index())
                            .getType()
                            .dyn_cast<MemRefType>();
      if (!memRefType || !memRefType.hasStaticShape())
        return failure();

      // Get the global variable in the SPIR-V module that is associated with
      // the kernel operand. Use `lookup` rather than `operator[]`: the latter
      // would default-insert a null op for an operand without a matching
      // global, and that op would then be dereferenced below.
      spirv::GlobalVariableOp spirvGlobal =
          globalVariableMap.lookup(operand.index());
      if (!spirvGlobal)
        return failure();
      auto pointeeType =
          spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
      auto dstGlobalType = typeConverter->convertType(pointeeType);
      if (!dstGlobalType)
        return failure();

      // Calculate the size of the memref and get the pointer to the allocated
      // buffer.
      // NOTE(review): the copy starts at the allocated pointer rather than at
      // the aligned pointer plus offset. This presumably assumes the two
      // coincide for kernel operands — confirm.
      SmallVector<Value, 4> sizes;
      SmallVector<Value, 4> strides;
      Value sizeBytes;
      getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
                               sizeBytes);
      MemRefDescriptor descriptor(operand.value());
      Value src = descriptor.allocatedPtr(rewriter, loc);

      // Construct the global's new name and create a corresponding LLVM
      // dialect global variable unless one was already created.
      std::string name =
          createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
      auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
      if (!dstGlobal) {
        OpBuilder::InsertionGuard guard(rewriter);
        rewriter.setInsertionPointToStart(module.getBody());
        dstGlobal = rewriter.create<LLVM::GlobalOp>(
            loc, dstGlobalType,
            /*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
            /*alignment=*/0);
      }

      // Copy the data from src operand pointer to dst global variable. Save
      // src, dst and size so that we can copy data back after emulating the
      // kernel call.
      Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstGlobal);
      copy(loc, dst, src, sizeBytes, rewriter);
      copyInfo.push_back(CopyInfo{dst, src, sizeBytes});
    }

    // Create a call to the kernel and copy the data back.
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
                                              ArrayRef<Value>());
    for (const CopyInfo &info : copyInfo)
      copy(loc, info.src, info.dst, info.size, rewriter);
    return success();
  }
};

/// The pass runs in three steps:
///   1. Erase the `gpu.module`s.
///   2. Lower the remaining host code (MemRef, Standard and `gpu.launch_func`
///      ops) to the LLVM dialect.
///   3. Rename each SPIR-V module's entry point function to avoid symbol
///      conflicts.
/// The pass fails if either the conversion or any renaming fails.
class LowerHostCodeToLLVM
    : public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> {
public:
  void runOnOperation() override {
    ModuleOp module = getOperation();
    MLIRContext *context = module.getContext();

    // Erase the GPU module.
    for (auto gpuModule :
         llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
      gpuModule.erase();

    // Specify options to lower Standard to LLVM and pull in the conversion
    // patterns.
    LowerToLLVMOptions options(context);
    options.emitCWrappers = true;
    RewritePatternSet patterns(context);
    LLVMTypeConverter typeConverter(context, options);
    populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
    populateStdToLLVMConversionPatterns(typeConverter, patterns);
    patterns.add<GPULaunchLowering>(typeConverter);

    // Pull in SPIR-V type conversion patterns to convert SPIR-V global
    // variable's type to LLVM dialect type.
    populateSPIRVToLLVMTypeConversion(typeConverter);

    ConversionTarget target(*context);
    target.addLegalDialect<LLVM::LLVMDialect>();
    if (failed(applyPartialConversion(module, target, std::move(patterns)))) {
      // Renaming kernels in an already failed pass is pointless.
      signalPassFailure();
      return;
    }

    // Finally, modify the kernel function in SPIR-V modules to avoid symbolic
    // conflicts. A failure here must fail the pass rather than be ignored.
    for (auto spvModule : module.getOps<spirv::ModuleOp>()) {
      if (failed(encodeKernelName(spvModule))) {
        signalPassFailure();
        return;
      }
    }
  }
};
} // namespace

/// Creates an instance of the pass that lowers host code, including
/// `gpu.launch_func` calls into SPIR-V kernels, to the LLVM dialect.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createLowerHostCodeToLLVMPass() {
  std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> pass =
      std::make_unique<LowerHostCodeToLLVM>();
  return pass;
}
