1ac0a70f3SAlex Zinenko //===- IRModule.cpp - IR pybind module ------------------------------------===//
2ac0a70f3SAlex Zinenko //
3ac0a70f3SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4ac0a70f3SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5ac0a70f3SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6ac0a70f3SAlex Zinenko //
7ac0a70f3SAlex Zinenko //===----------------------------------------------------------------------===//
8ac0a70f3SAlex Zinenko 
9ac0a70f3SAlex Zinenko #include "IRModule.h"
10ac0a70f3SAlex Zinenko #include "Globals.h"
11ac0a70f3SAlex Zinenko #include "PybindUtils.h"
12ac0a70f3SAlex Zinenko 
13ac0a70f3SAlex Zinenko #include <vector>
14ac0a70f3SAlex Zinenko 
15*cb7b0381SStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h"
16*cb7b0381SStella Laurenzo 
17ac0a70f3SAlex Zinenko namespace py = pybind11;
18ac0a70f3SAlex Zinenko using namespace mlir;
19ac0a70f3SAlex Zinenko using namespace mlir::python;
20ac0a70f3SAlex Zinenko 
21ac0a70f3SAlex Zinenko // -----------------------------------------------------------------------------
22ac0a70f3SAlex Zinenko // PyGlobals
23ac0a70f3SAlex Zinenko // -----------------------------------------------------------------------------
24ac0a70f3SAlex Zinenko 
25ac0a70f3SAlex Zinenko PyGlobals *PyGlobals::instance = nullptr;
26ac0a70f3SAlex Zinenko 
27ac0a70f3SAlex Zinenko PyGlobals::PyGlobals() {
28ac0a70f3SAlex Zinenko   assert(!instance && "PyGlobals already constructed");
29ac0a70f3SAlex Zinenko   instance = this;
30*cb7b0381SStella Laurenzo   // The default search path include {mlir.}dialects, where {mlir.} is the
31*cb7b0381SStella Laurenzo   // package prefix configured at compile time.
32*cb7b0381SStella Laurenzo   dialectSearchPrefixes.push_back(MAKE_MLIR_PYTHON_QUALNAME("dialects"));
33ac0a70f3SAlex Zinenko }
34ac0a70f3SAlex Zinenko 
35ac0a70f3SAlex Zinenko PyGlobals::~PyGlobals() { instance = nullptr; }
36ac0a70f3SAlex Zinenko 
37ac0a70f3SAlex Zinenko void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
38ac0a70f3SAlex Zinenko   py::gil_scoped_acquire();
39ac0a70f3SAlex Zinenko   if (loadedDialectModulesCache.contains(dialectNamespace))
40ac0a70f3SAlex Zinenko     return;
41ac0a70f3SAlex Zinenko   // Since re-entrancy is possible, make a copy of the search prefixes.
42ac0a70f3SAlex Zinenko   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
43ac0a70f3SAlex Zinenko   py::object loaded;
44ac0a70f3SAlex Zinenko   for (std::string moduleName : localSearchPrefixes) {
45ac0a70f3SAlex Zinenko     moduleName.push_back('.');
46ac0a70f3SAlex Zinenko     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
47ac0a70f3SAlex Zinenko 
48ac0a70f3SAlex Zinenko     try {
49ac0a70f3SAlex Zinenko       py::gil_scoped_release();
50ac0a70f3SAlex Zinenko       loaded = py::module::import(moduleName.c_str());
51ac0a70f3SAlex Zinenko     } catch (py::error_already_set &e) {
52ac0a70f3SAlex Zinenko       if (e.matches(PyExc_ModuleNotFoundError)) {
53ac0a70f3SAlex Zinenko         continue;
54ac0a70f3SAlex Zinenko       } else {
55ac0a70f3SAlex Zinenko         throw;
56ac0a70f3SAlex Zinenko       }
57ac0a70f3SAlex Zinenko     }
58ac0a70f3SAlex Zinenko     break;
59ac0a70f3SAlex Zinenko   }
60ac0a70f3SAlex Zinenko 
61ac0a70f3SAlex Zinenko   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
62ac0a70f3SAlex Zinenko   // may have occurred, which may do anything.
63ac0a70f3SAlex Zinenko   loadedDialectModulesCache.insert(dialectNamespace);
64ac0a70f3SAlex Zinenko }
65ac0a70f3SAlex Zinenko 
66ac0a70f3SAlex Zinenko void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
67ac0a70f3SAlex Zinenko                                     py::object pyClass) {
68ac0a70f3SAlex Zinenko   py::gil_scoped_acquire();
69ac0a70f3SAlex Zinenko   py::object &found = dialectClassMap[dialectNamespace];
70ac0a70f3SAlex Zinenko   if (found) {
71ac0a70f3SAlex Zinenko     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
72ac0a70f3SAlex Zinenko                                              dialectNamespace +
73ac0a70f3SAlex Zinenko                                              "' is already registered.");
74ac0a70f3SAlex Zinenko   }
75ac0a70f3SAlex Zinenko   found = std::move(pyClass);
76ac0a70f3SAlex Zinenko }
77ac0a70f3SAlex Zinenko 
78ac0a70f3SAlex Zinenko void PyGlobals::registerOperationImpl(const std::string &operationName,
79ac0a70f3SAlex Zinenko                                       py::object pyClass,
80ac0a70f3SAlex Zinenko                                       py::object rawOpViewClass) {
81ac0a70f3SAlex Zinenko   py::gil_scoped_acquire();
82ac0a70f3SAlex Zinenko   py::object &found = operationClassMap[operationName];
83ac0a70f3SAlex Zinenko   if (found) {
84ac0a70f3SAlex Zinenko     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
85ac0a70f3SAlex Zinenko                                              operationName +
86ac0a70f3SAlex Zinenko                                              "' is already registered.");
87ac0a70f3SAlex Zinenko   }
88ac0a70f3SAlex Zinenko   found = std::move(pyClass);
89ac0a70f3SAlex Zinenko   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
90ac0a70f3SAlex Zinenko }
91ac0a70f3SAlex Zinenko 
92ac0a70f3SAlex Zinenko llvm::Optional<py::object>
93ac0a70f3SAlex Zinenko PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
94ac0a70f3SAlex Zinenko   py::gil_scoped_acquire();
95ac0a70f3SAlex Zinenko   loadDialectModule(dialectNamespace);
96ac0a70f3SAlex Zinenko   // Fast match against the class map first (common case).
97ac0a70f3SAlex Zinenko   const auto foundIt = dialectClassMap.find(dialectNamespace);
98ac0a70f3SAlex Zinenko   if (foundIt != dialectClassMap.end()) {
99ac0a70f3SAlex Zinenko     if (foundIt->second.is_none())
100ac0a70f3SAlex Zinenko       return llvm::None;
101ac0a70f3SAlex Zinenko     assert(foundIt->second && "py::object is defined");
102ac0a70f3SAlex Zinenko     return foundIt->second;
103ac0a70f3SAlex Zinenko   }
104ac0a70f3SAlex Zinenko 
105ac0a70f3SAlex Zinenko   // Not found and loading did not yield a registration. Negative cache.
106ac0a70f3SAlex Zinenko   dialectClassMap[dialectNamespace] = py::none();
107ac0a70f3SAlex Zinenko   return llvm::None;
108ac0a70f3SAlex Zinenko }
109ac0a70f3SAlex Zinenko 
110ac0a70f3SAlex Zinenko llvm::Optional<pybind11::object>
111ac0a70f3SAlex Zinenko PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
112ac0a70f3SAlex Zinenko   {
113ac0a70f3SAlex Zinenko     py::gil_scoped_acquire();
114ac0a70f3SAlex Zinenko     auto foundIt = rawOpViewClassMapCache.find(operationName);
115ac0a70f3SAlex Zinenko     if (foundIt != rawOpViewClassMapCache.end()) {
116ac0a70f3SAlex Zinenko       if (foundIt->second.is_none())
117ac0a70f3SAlex Zinenko         return llvm::None;
118ac0a70f3SAlex Zinenko       assert(foundIt->second && "py::object is defined");
119ac0a70f3SAlex Zinenko       return foundIt->second;
120ac0a70f3SAlex Zinenko     }
121ac0a70f3SAlex Zinenko   }
122ac0a70f3SAlex Zinenko 
123ac0a70f3SAlex Zinenko   // Not found. Load the dialect namespace.
124ac0a70f3SAlex Zinenko   auto split = operationName.split('.');
125ac0a70f3SAlex Zinenko   llvm::StringRef dialectNamespace = split.first;
126ac0a70f3SAlex Zinenko   loadDialectModule(dialectNamespace);
127ac0a70f3SAlex Zinenko 
128ac0a70f3SAlex Zinenko   // Attempt to find from the canonical map and cache.
129ac0a70f3SAlex Zinenko   {
130ac0a70f3SAlex Zinenko     py::gil_scoped_acquire();
131ac0a70f3SAlex Zinenko     auto foundIt = rawOpViewClassMap.find(operationName);
132ac0a70f3SAlex Zinenko     if (foundIt != rawOpViewClassMap.end()) {
133ac0a70f3SAlex Zinenko       if (foundIt->second.is_none())
134ac0a70f3SAlex Zinenko         return llvm::None;
135ac0a70f3SAlex Zinenko       assert(foundIt->second && "py::object is defined");
136ac0a70f3SAlex Zinenko       // Positive cache.
137ac0a70f3SAlex Zinenko       rawOpViewClassMapCache[operationName] = foundIt->second;
138ac0a70f3SAlex Zinenko       return foundIt->second;
139ac0a70f3SAlex Zinenko     } else {
140ac0a70f3SAlex Zinenko       // Negative cache.
141ac0a70f3SAlex Zinenko       rawOpViewClassMap[operationName] = py::none();
142ac0a70f3SAlex Zinenko       return llvm::None;
143ac0a70f3SAlex Zinenko     }
144ac0a70f3SAlex Zinenko   }
145ac0a70f3SAlex Zinenko }
146ac0a70f3SAlex Zinenko 
147ac0a70f3SAlex Zinenko void PyGlobals::clearImportCache() {
148ac0a70f3SAlex Zinenko   py::gil_scoped_acquire();
149ac0a70f3SAlex Zinenko   loadedDialectModulesCache.clear();
150ac0a70f3SAlex Zinenko   rawOpViewClassMapCache.clear();
151ac0a70f3SAlex Zinenko }
152