1*ac0a70f3SAlex Zinenko //===- IRModule.cpp - IR pybind module ------------------------------------===//
2*ac0a70f3SAlex Zinenko //
3*ac0a70f3SAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*ac0a70f3SAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5*ac0a70f3SAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*ac0a70f3SAlex Zinenko //
7*ac0a70f3SAlex Zinenko //===----------------------------------------------------------------------===//
8*ac0a70f3SAlex Zinenko 
9*ac0a70f3SAlex Zinenko #include "IRModule.h"
10*ac0a70f3SAlex Zinenko #include "Globals.h"
11*ac0a70f3SAlex Zinenko #include "PybindUtils.h"
12*ac0a70f3SAlex Zinenko 
13*ac0a70f3SAlex Zinenko #include <vector>
14*ac0a70f3SAlex Zinenko 
15*ac0a70f3SAlex Zinenko namespace py = pybind11;
16*ac0a70f3SAlex Zinenko using namespace mlir;
17*ac0a70f3SAlex Zinenko using namespace mlir::python;
18*ac0a70f3SAlex Zinenko 
19*ac0a70f3SAlex Zinenko // -----------------------------------------------------------------------------
20*ac0a70f3SAlex Zinenko // PyGlobals
21*ac0a70f3SAlex Zinenko // -----------------------------------------------------------------------------
22*ac0a70f3SAlex Zinenko 
23*ac0a70f3SAlex Zinenko PyGlobals *PyGlobals::instance = nullptr;
24*ac0a70f3SAlex Zinenko 
25*ac0a70f3SAlex Zinenko PyGlobals::PyGlobals() {
26*ac0a70f3SAlex Zinenko   assert(!instance && "PyGlobals already constructed");
27*ac0a70f3SAlex Zinenko   instance = this;
28*ac0a70f3SAlex Zinenko }
29*ac0a70f3SAlex Zinenko 
30*ac0a70f3SAlex Zinenko PyGlobals::~PyGlobals() { instance = nullptr; }
31*ac0a70f3SAlex Zinenko 
32*ac0a70f3SAlex Zinenko void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) {
33*ac0a70f3SAlex Zinenko   py::gil_scoped_acquire();
34*ac0a70f3SAlex Zinenko   if (loadedDialectModulesCache.contains(dialectNamespace))
35*ac0a70f3SAlex Zinenko     return;
36*ac0a70f3SAlex Zinenko   // Since re-entrancy is possible, make a copy of the search prefixes.
37*ac0a70f3SAlex Zinenko   std::vector<std::string> localSearchPrefixes = dialectSearchPrefixes;
38*ac0a70f3SAlex Zinenko   py::object loaded;
39*ac0a70f3SAlex Zinenko   for (std::string moduleName : localSearchPrefixes) {
40*ac0a70f3SAlex Zinenko     moduleName.push_back('.');
41*ac0a70f3SAlex Zinenko     moduleName.append(dialectNamespace.data(), dialectNamespace.size());
42*ac0a70f3SAlex Zinenko 
43*ac0a70f3SAlex Zinenko     try {
44*ac0a70f3SAlex Zinenko       py::gil_scoped_release();
45*ac0a70f3SAlex Zinenko       loaded = py::module::import(moduleName.c_str());
46*ac0a70f3SAlex Zinenko     } catch (py::error_already_set &e) {
47*ac0a70f3SAlex Zinenko       if (e.matches(PyExc_ModuleNotFoundError)) {
48*ac0a70f3SAlex Zinenko         continue;
49*ac0a70f3SAlex Zinenko       } else {
50*ac0a70f3SAlex Zinenko         throw;
51*ac0a70f3SAlex Zinenko       }
52*ac0a70f3SAlex Zinenko     }
53*ac0a70f3SAlex Zinenko     break;
54*ac0a70f3SAlex Zinenko   }
55*ac0a70f3SAlex Zinenko 
56*ac0a70f3SAlex Zinenko   // Note: Iterator cannot be shared from prior to loading, since re-entrancy
57*ac0a70f3SAlex Zinenko   // may have occurred, which may do anything.
58*ac0a70f3SAlex Zinenko   loadedDialectModulesCache.insert(dialectNamespace);
59*ac0a70f3SAlex Zinenko }
60*ac0a70f3SAlex Zinenko 
61*ac0a70f3SAlex Zinenko void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
62*ac0a70f3SAlex Zinenko                                     py::object pyClass) {
63*ac0a70f3SAlex Zinenko   py::gil_scoped_acquire();
64*ac0a70f3SAlex Zinenko   py::object &found = dialectClassMap[dialectNamespace];
65*ac0a70f3SAlex Zinenko   if (found) {
66*ac0a70f3SAlex Zinenko     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") +
67*ac0a70f3SAlex Zinenko                                              dialectNamespace +
68*ac0a70f3SAlex Zinenko                                              "' is already registered.");
69*ac0a70f3SAlex Zinenko   }
70*ac0a70f3SAlex Zinenko   found = std::move(pyClass);
71*ac0a70f3SAlex Zinenko }
72*ac0a70f3SAlex Zinenko 
73*ac0a70f3SAlex Zinenko void PyGlobals::registerOperationImpl(const std::string &operationName,
74*ac0a70f3SAlex Zinenko                                       py::object pyClass,
75*ac0a70f3SAlex Zinenko                                       py::object rawOpViewClass) {
76*ac0a70f3SAlex Zinenko   py::gil_scoped_acquire();
77*ac0a70f3SAlex Zinenko   py::object &found = operationClassMap[operationName];
78*ac0a70f3SAlex Zinenko   if (found) {
79*ac0a70f3SAlex Zinenko     throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") +
80*ac0a70f3SAlex Zinenko                                              operationName +
81*ac0a70f3SAlex Zinenko                                              "' is already registered.");
82*ac0a70f3SAlex Zinenko   }
83*ac0a70f3SAlex Zinenko   found = std::move(pyClass);
84*ac0a70f3SAlex Zinenko   rawOpViewClassMap[operationName] = std::move(rawOpViewClass);
85*ac0a70f3SAlex Zinenko }
86*ac0a70f3SAlex Zinenko 
87*ac0a70f3SAlex Zinenko llvm::Optional<py::object>
88*ac0a70f3SAlex Zinenko PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
89*ac0a70f3SAlex Zinenko   py::gil_scoped_acquire();
90*ac0a70f3SAlex Zinenko   loadDialectModule(dialectNamespace);
91*ac0a70f3SAlex Zinenko   // Fast match against the class map first (common case).
92*ac0a70f3SAlex Zinenko   const auto foundIt = dialectClassMap.find(dialectNamespace);
93*ac0a70f3SAlex Zinenko   if (foundIt != dialectClassMap.end()) {
94*ac0a70f3SAlex Zinenko     if (foundIt->second.is_none())
95*ac0a70f3SAlex Zinenko       return llvm::None;
96*ac0a70f3SAlex Zinenko     assert(foundIt->second && "py::object is defined");
97*ac0a70f3SAlex Zinenko     return foundIt->second;
98*ac0a70f3SAlex Zinenko   }
99*ac0a70f3SAlex Zinenko 
100*ac0a70f3SAlex Zinenko   // Not found and loading did not yield a registration. Negative cache.
101*ac0a70f3SAlex Zinenko   dialectClassMap[dialectNamespace] = py::none();
102*ac0a70f3SAlex Zinenko   return llvm::None;
103*ac0a70f3SAlex Zinenko }
104*ac0a70f3SAlex Zinenko 
105*ac0a70f3SAlex Zinenko llvm::Optional<pybind11::object>
106*ac0a70f3SAlex Zinenko PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) {
107*ac0a70f3SAlex Zinenko   {
108*ac0a70f3SAlex Zinenko     py::gil_scoped_acquire();
109*ac0a70f3SAlex Zinenko     auto foundIt = rawOpViewClassMapCache.find(operationName);
110*ac0a70f3SAlex Zinenko     if (foundIt != rawOpViewClassMapCache.end()) {
111*ac0a70f3SAlex Zinenko       if (foundIt->second.is_none())
112*ac0a70f3SAlex Zinenko         return llvm::None;
113*ac0a70f3SAlex Zinenko       assert(foundIt->second && "py::object is defined");
114*ac0a70f3SAlex Zinenko       return foundIt->second;
115*ac0a70f3SAlex Zinenko     }
116*ac0a70f3SAlex Zinenko   }
117*ac0a70f3SAlex Zinenko 
118*ac0a70f3SAlex Zinenko   // Not found. Load the dialect namespace.
119*ac0a70f3SAlex Zinenko   auto split = operationName.split('.');
120*ac0a70f3SAlex Zinenko   llvm::StringRef dialectNamespace = split.first;
121*ac0a70f3SAlex Zinenko   loadDialectModule(dialectNamespace);
122*ac0a70f3SAlex Zinenko 
123*ac0a70f3SAlex Zinenko   // Attempt to find from the canonical map and cache.
124*ac0a70f3SAlex Zinenko   {
125*ac0a70f3SAlex Zinenko     py::gil_scoped_acquire();
126*ac0a70f3SAlex Zinenko     auto foundIt = rawOpViewClassMap.find(operationName);
127*ac0a70f3SAlex Zinenko     if (foundIt != rawOpViewClassMap.end()) {
128*ac0a70f3SAlex Zinenko       if (foundIt->second.is_none())
129*ac0a70f3SAlex Zinenko         return llvm::None;
130*ac0a70f3SAlex Zinenko       assert(foundIt->second && "py::object is defined");
131*ac0a70f3SAlex Zinenko       // Positive cache.
132*ac0a70f3SAlex Zinenko       rawOpViewClassMapCache[operationName] = foundIt->second;
133*ac0a70f3SAlex Zinenko       return foundIt->second;
134*ac0a70f3SAlex Zinenko     } else {
135*ac0a70f3SAlex Zinenko       // Negative cache.
136*ac0a70f3SAlex Zinenko       rawOpViewClassMap[operationName] = py::none();
137*ac0a70f3SAlex Zinenko       return llvm::None;
138*ac0a70f3SAlex Zinenko     }
139*ac0a70f3SAlex Zinenko   }
140*ac0a70f3SAlex Zinenko }
141*ac0a70f3SAlex Zinenko 
142*ac0a70f3SAlex Zinenko void PyGlobals::clearImportCache() {
143*ac0a70f3SAlex Zinenko   py::gil_scoped_acquire();
144*ac0a70f3SAlex Zinenko   loadedDialectModulesCache.clear();
145*ac0a70f3SAlex Zinenko   rawOpViewClassMapCache.clear();
146*ac0a70f3SAlex Zinenko }
147