//===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the SPIR-V module combiner library.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"

using namespace mlir;

/// Upper bound on the numeric suffix counter used by `renameSymbol` when it
/// searches for an unused symbol name.
static constexpr unsigned maxFreeID = 1 << 20;

/// Produces a replacement name for `oldSymName` that is not yet taken by any
/// symbol in `module`.
///
/// Candidates have the form `<oldSymName>_<N>`, where N is obtained by
/// pre-incrementing `lastUsedID`; the counter is shared with the caller, so a
/// suffix that has been tried once is never tried again. If the counter
/// reaches `maxFreeID` without a free candidate being found, the bare
/// `<oldSymName>_` prefix is returned without any uniqueness check.
static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
                               spirv::ModuleOp module) {
  MLIRContext *context = module->getContext();

  SmallString<64> prefix(oldSymName);
  prefix.push_back('_');

  for (; lastUsedID < maxFreeID;) {
    ++lastUsedID;
    StringAttr candidate =
        StringAttr::get(context, prefix + Twine(lastUsedID));
    if (SymbolTable::lookupSymbolIn(module, candidate))
      continue;
    return candidate;
  }

  // Suffix space exhausted: fall back to the undecorated prefix.
  return StringAttr::get(context, prefix);
}

/// Checks if a symbol with the same name as `op` already exists in `source`.
/// If so, renames `op` and updates all its references in `target`.
///
/// `op` is expected to live in `target`. The replacement name is chosen so
/// that it is unused in `target` *and* in `source`: the two modules are about
/// to be merged, so a name that is only free in `target` could still clash
/// with an unrelated symbol of `source` once both end up in the same symbol
/// table.
///
/// `lastUsedID` is the shared suffix counter; it is advanced for every
/// candidate name that is tried.
///
/// Returns success if no rename was needed or the rename succeeded, and
/// failure (after emitting an error on `op`) if some symbol use could not be
/// updated.
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
                                            spirv::ModuleOp target,
                                            spirv::ModuleOp source,
                                            unsigned &lastUsedID) {
  if (!SymbolTable::lookupSymbolIn(source, op.getName()))
    return success();

  StringRef oldSymName = op.getName();

  // `renameSymbol` only guarantees that the candidate is unused in `target`.
  // Keep drawing candidates while the name is still taken in `source`. Each
  // call made while `lastUsedID < maxFreeID` advances the counter, so the loop
  // terminates; once the counter is exhausted the last candidate is used
  // as-is, matching `renameSymbol`'s own fallback behavior.
  StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);
  while (lastUsedID < maxFreeID &&
         SymbolTable::lookupSymbolIn(source, newSymName))
    newSymName = renameSymbol(oldSymName, lastUsedID, target);

  if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
    return op.emitError("unable to update all symbol uses for ")
           << oldSymName << " to " << newSymName;

  SymbolTable::setSymbolName(op, newSymName);
  return success();
}

/// Hashes `symbolOp` from its operation name plus every attribute it carries,
/// with the symbol-name attribute left out.
///
/// Note: The operation's name (as opposed to the symbol name) is mixed into
/// the hash on purpose. Without it, ops of different kinds — e.g. a global
/// variable and a spec constant — could be mistaken for duplicates merely
/// because their remaining attributes (descriptor set + binding vs. spec_id)
/// happen to hash to the same value.
static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
  // Keeps every attribute except the one that stores the symbol's name.
  auto isNotSymbolName = [](NamedAttribute attr) {
    return attr.getName() != SymbolTable::getSymbolAttrName();
  };

  auto hashedAttrs =
      llvm::make_filter_range(symbolOp->getAttrs(), isNotSymbolName);
  llvm::hash_code attrsHash =
      llvm::hash_combine_range(hashedAttrs.begin(), hashedAttrs.end());

  return llvm::hash_combine(symbolOp->getName(), attrsHash);
}

namespace mlir {
namespace spirv {

/// Merges every module in `inputModules` into a single new spv.module created
/// through `combinedModuleBuilder`.
///
/// All inputs must agree with the first module on addressing model, memory
/// model and VCE triple; otherwise an error is emitted on the offending module
/// and a null reference is returned. A null reference is also returned for an
/// empty input list and whenever a symbol rewrite fails.
///
/// Name clashes between the combined module and an incoming module are
/// resolved by renaming. When `symRenameListener` is provided, it is invoked
/// for each rename with the input module the symbol originated from, the old
/// name and the new name. After merging, non-function symbol ops (without
/// operands/results) whose hashes from `computeHash` coincide are collapsed
/// into the first such op.
///
/// NOTE(review): the failure paths taken after the combined module has been
/// created return null without erasing that module here — confirm who is
/// responsible for cleaning it up.
OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
                                     OpBuilder &combinedModuleBuilder,
                                     SymbolRenameListener symRenameListener) {
  if (inputModules.empty())
    return nullptr;

  // The first module dictates the models and the VCE triple that every other
  // input has to match.
  spirv::ModuleOp referenceModule = inputModules.front();
  auto expectedAddressingModel = referenceModule.addressing_model();
  auto expectedMemoryModel = referenceModule.memory_model();
  auto expectedVceTriple = referenceModule.vce_triple();

  // Bail out before building anything if any input is incompatible.
  for (auto candidate : inputModules) {
    bool mismatch = candidate.addressing_model() != expectedAddressingModel ||
                    candidate.memory_model() != expectedMemoryModel ||
                    candidate.vce_triple() != expectedVceTriple;
    if (!mismatch)
      continue;

    candidate.emitError("input modules differ in addressing model, memory "
                        "model, and/or VCE triple");
    return nullptr;
  }

  auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
      referenceModule.getLoc(), expectedAddressingModel, expectedMemoryModel,
      expectedVceTriple);
  combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());

  // Remembers, for each symbol name currently in the combined module, which
  // input module contributed it. A symbol that already sits in the combined
  // module may have to be renamed to make room for a symbol of a later input
  // (e.g. a global variable giving way to a function); the listener must then
  // be told about the rename in terms of the module the symbol came from.
  // Only maintained when a listener is present.
  llvm::StringMap<spirv::ModuleOp> symbolOrigins;

  // Suffix counter shared by all renames performed during this call.
  unsigned suffixCounter = 0;

  for (auto inputModule : inputModules) {
    // All renaming happens on a private copy; the input stays untouched.
    OwningOpRef<spirv::ModuleOp> workingCopy = inputModule.clone();

    // Pass 1: inside the combined module, rename symbols that clash with the
    // incoming module. spv.func ops are exempt and keep their names; if the
    // clashing op on the incoming side is not an spv.func, it is the one that
    // gets renamed (in pass 2) instead.
    for (auto &existingOp : *combinedModule.getBody()) {
      auto existingSym = dyn_cast<SymbolOpInterface>(existingOp);
      if (!existingSym || isa<FuncOp>(existingOp))
        continue;

      StringRef nameBefore = existingSym.getName();
      if (failed(updateSymbolAndAllUses(existingSym, combinedModule,
                                        *workingCopy, suffixCounter)))
        return nullptr;
      StringRef nameAfter = existingSym.getName();

      // Nothing to report if nobody listens or the name did not change.
      if (!symRenameListener || nameBefore == nameAfter)
        continue;

      spirv::ModuleOp originModule = symbolOrigins.lookup(nameBefore);
      if (!originModule) {
        inputModule.emitError(
            "unable to find original spirv::ModuleOp for symbol ")
            << nameBefore;
        return nullptr;
      }

      symRenameListener(originModule, nameBefore, nameAfter);

      // Move the bookkeeping entry over to the new name: the old name is no
      // longer in use, and the symbol may be renamed again by a later input.
      symbolOrigins.erase(nameBefore);
      symbolOrigins[nameAfter] = originModule;
    }

    // Pass 2: inside the copy of the incoming module, rename whatever still
    // clashes with the combined module. spv.func ops are included here.
    for (auto &incomingOp : *workingCopy->getBody()) {
      auto incomingSym = dyn_cast<SymbolOpInterface>(incomingOp);
      if (!incomingSym)
        continue;

      StringRef nameBefore = incomingSym.getName();
      if (failed(updateSymbolAndAllUses(incomingSym, *workingCopy,
                                        combinedModule, suffixCounter)))
        return nullptr;
      StringRef nameAfter = incomingSym.getName();

      if (!symRenameListener)
        continue;

      if (nameBefore != nameAfter)
        symRenameListener(inputModule, nameBefore, nameAfter);

      // Record which input module this (final) symbol name belongs to.
      bool inserted = symbolOrigins.try_emplace(nameAfter, inputModule).second;
      if (inserted)
        continue;

      // An existing entry under this name points at a bookkeeping problem,
      // most likely in the map clean-up performed during pass 1.
      inputModule.emitError("did not expect to find an entry for symbol ")
          << incomingSym.getName();
      return nullptr;
    }

    // Copy the (now conflict-free) ops over into the combined module.
    for (auto &incomingOp : *workingCopy->getBody())
      combinedModuleBuilder.insert(incomingOp.clone());
  }

  // Deduplication: the first op seen for each hash is kept; later ops with the
  // same hash have their uses redirected to it and are erased afterwards.
  DenseMap<llvm::hash_code, SymbolOpInterface> firstOpForHash;
  SmallVector<SymbolOpInterface, 0> duplicates;

  for (auto &op : *combinedModule.getBody()) {
    SymbolOpInterface candidate = dyn_cast<SymbolOpInterface>(op);
    if (!candidate)
      continue;

    // Skipped: ops with operands or results (global variables, spec constants
    // and functions have none, so this is only a safety net), and functions,
    // whose deduplication is not supported yet.
    if (op.getNumOperands() != 0 || op.getNumResults() != 0 ||
        isa<FuncOp>(op))
      continue;

    auto insertion =
        firstOpForHash.try_emplace(computeHash(candidate), candidate);
    if (insertion.second)
      continue;

    SymbolOpInterface keeper = insertion.first->second;

    if (failed(SymbolTable::replaceAllSymbolUses(
            candidate, keeper.getNameAttr(), combinedModule))) {
      candidate.emitError("unable to update all symbol uses for ")
          << candidate.getName() << " to " << keeper.getName();
      return nullptr;
    }

    // Erasure is deferred so the block is not mutated while being iterated.
    duplicates.push_back(candidate);
  }

  for (auto duplicate : duplicates)
    duplicate.erase();

  return combinedModule;
}

} // namespace spirv
} // namespace mlir
