//===- IRModule.cpp - IR pybind module ------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "IRModule.h"
#include "Globals.h"
#include "PybindUtils.h"

#include <vector>

namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;

// -----------------------------------------------------------------------------
// PyGlobals
// -----------------------------------------------------------------------------

// Pointer to the single live PyGlobals object, or null when none exists. It is
// set by the constructor and reset by the destructor below.
PyGlobals *PyGlobals::instance = nullptr;

/// Publishes this object as the process-wide PyGlobals instance. Debug builds
/// assert that no other instance is currently live.
PyGlobals::PyGlobals() {
  assert(instance == nullptr && "PyGlobals already constructed");
  instance = this;
}

/// Unpublishes the instance so that a new PyGlobals may be constructed later.
PyGlobals::~PyGlobals() {
  instance = nullptr;
}

/// Imports the Python module implementing the dialect `dialectNamespace`.
///
/// Each entry of `dialectSearchPrefixes` is tried in order as
/// "<prefix>.<dialectNamespace>", and the search stops at the first successful
/// import. A prefix for which Python raises ModuleNotFoundError is skipped. Any
/// other Python exception propagates to the caller, and in that case the
/// namespace is NOT marked as loaded.
///
/// Otherwise the namespace is recorded in `loadedDialectModulesCache`, whether
/// or not a module was found. Later calls then return immediately until
/// clearImportCache() is called.
void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
  // Hold the GIL for the whole function: it calls into the CPython API and
  // mutates state shared with Python callbacks. The guard must be a named
  // object; a bare `py::gil_scoped_acquire();` is a temporary that is
  // destroyed at the end of the statement.
  py::gil_scoped_acquire acquire;
  if (loadedDialectModulesCache.contains(dialectNamespace))
    return;

  // Importing runs arbitrary Python code that may re-enter this object, so
  // iterate over a snapshot of the search prefixes.
  std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
  for (const std::string &prefix : localSearchPrefixes) {
    std::string moduleName = prefix;
    moduleName.push_back('.');
    moduleName.append(dialectNamespace.data(), dialectNamespace.size());

    try {
      // Do NOT release the GIL here: py::module::import calls the CPython API,
      // which requires the GIL to be held.
      py::module::import(moduleName.c_str());
    } catch (py::error_already_set &e) {
      // No module under this prefix; try the next one.
      if (e.matches(PyExc_ModuleNotFoundError))
        continue;
      throw;
    }
    break;
  }

  // Note: an iterator into the cache cannot be obtained before the loop and
  // reused here, since re-entrancy during import may have modified the cache.
  loadedDialectModulesCache.insert(dialectNamespace);
}

/// Registers `pyClass` as the Python class for the dialect `dialectNamespace`.
///
/// Throws the error produced by SetPyError(PyExc_RuntimeError, ...) if a class
/// is already registered for the namespace.
///
/// A None entry does not count as a registration. lookupDialectClass stores
/// py::none() in dialectClassMap as a negative-lookup marker, and a real
/// registration must be allowed to replace it.
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
                                    py::object pyClass) {
  // Named guard so the GIL is held for the whole function, not just one
  // statement.
  py::gil_scoped_acquire acquire;
  py::object &found = dialectClassMap[dialectNamespace];
  // py::object converts to true for None (non-null handle), so the None check
  // must be explicit.
  if (found && !found.is_none()) {
    throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
                                             dialectNamespace +
                                             "' is already registered.");
  }
  found = std::move(pyClass);
}

/// Registers `pyClass` as the Python class for the operation `operationName`,
/// and `rawOpViewClass` as its raw OpView class.
///
/// Throws the error produced by SetPyError(PyExc_RuntimeError, ...) if a class
/// is already registered in operationClassMap for this operation. As with
/// dialect registration, a None entry is treated as "not registered".
void PyGlobals::registerOperationImpl(const std::string &operationName,
                                      py::object pyClass,
                                      py::object rawOpViewClass) {
  // Named guard so the GIL is held for the whole function, not just one
  // statement.
  py::gil_scoped_acquire acquire;
  py::object &found = operationClassMap[operationName];
  // py::object converts to true for None (non-null handle), so the None check
  // must be explicit.
  if (found && !found.is_none()) {
    throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
                                             operationName +
                                             "' is already registered.");
  }
  found = std::move(pyClass);
  // Unconditional assignment. This also replaces any py::none() marker that
  // lookupRawOpViewClass stored for this name.
  rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
}

/// Returns the Python class registered for the dialect `dialectNamespace`, or
/// llvm::None if there is none.
///
/// The dialect's Python module is imported first (via loadDialectModule), since
/// importing it may register the class. If no class is found, py::none() is
/// stored in dialectClassMap as a negative marker. registerDialectImpl can
/// still replace that marker later.
llvm::Optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
  // Named guard so the GIL is held for the whole function, not just one
  // statement.
  py::gil_scoped_acquire acquire;
  loadDialectModule(dialectNamespace);

  // Look up the class map now that any import-time registration has happened.
  const auto foundIt = dialectClassMap.find(dialectNamespace);
  if (foundIt != dialectClassMap.end()) {
    // A None entry is a negative marker from an earlier miss.
    if (foundIt->second.is_none())
      return llvm::None;
    assert(foundIt->second && "py::object is defined");
    return foundIt->second;
  }

  // Not found and loading did not yield a registration. Negative cache.
  dialectClassMap[dialectNamespace] = py::none();
  return llvm::None;
}

/// Returns the raw OpView class registered for `operationName` (of the form
/// "<dialect>.<op>"), or llvm::None if there is none.
///
/// The lookup runs in three steps:
///   1. Check rawOpViewClassMapCache.
///   2. On a miss, import the owning dialect's module (which may register the
///      class).
///   3. Consult the canonical rawOpViewClassMap.
///
/// Positive results are copied into the cache. A miss is recorded as
/// py::none() in the canonical map rather than in the cache, so that a later
/// registerOperationImpl overwrites it and becomes visible to subsequent
/// lookups.
llvm::Optional<pybind11::object>
PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
  // One named guard for the whole function. The previous
  // `py::gil_scoped_acquire();` statements were temporaries destroyed
  // immediately. Re-acquiring inside loadDialectModule nests safely.
  py::gil_scoped_acquire acquire;

  auto cachedIt = rawOpViewClassMapCache.find(operationName);
  if (cachedIt != rawOpViewClassMapCache.end()) {
    if (cachedIt->second.is_none())
      return llvm::None;
    assert(cachedIt->second && "py::object is defined");
    return cachedIt->second;
  }

  // Not cached: import the dialect, i.e. the text before the first '.'.
  llvm::StringRef dialectNamespace = operationName.split('.').first;
  loadDialectModule(dialectNamespace);

  auto foundIt = rawOpViewClassMap.find(operationName);
  if (foundIt == rawOpViewClassMap.end()) {
    // Negative marker in the canonical map. registerOperationImpl replaces it.
    rawOpViewClassMap[operationName] = py::none();
    return llvm::None;
  }
  if (foundIt->second.is_none())
    return llvm::None;
  assert(foundIt->second && "py::object is defined");
  // Positive cache.
  rawOpViewClassMapCache[operationName] = foundIt->second;
  return foundIt->second;
}

/// Forgets which dialect modules have been imported and drops all cached raw
/// OpView lookups. The next lookups will therefore re-run the module search
/// and consult the canonical maps again.
void PyGlobals::clearImportCache() {
  // Named guard so the GIL is actually held while the caches (which own
  // py::object references) are cleared.
  py::gil_scoped_acquire acquire;
  loadedDialectModulesCache.clear();
  rawOpViewClassMapCache.clear();
}
