1436c6c9cSStella Laurenzo //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9436c6c9cSStella Laurenzo #include "IRModule.h"
10436c6c9cSStella Laurenzo 
11436c6c9cSStella Laurenzo #include "Globals.h"
12436c6c9cSStella Laurenzo #include "PybindUtils.h"
13436c6c9cSStella Laurenzo 
14436c6c9cSStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h"
15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
174acd8457SAlex Zinenko #include "mlir-c/Debug.h"
183f3d1c90SMike Urbach #include "mlir-c/IR.h"
195e83a5b4SStella Laurenzo //#include "mlir-c/Registration.h"
20e67cbbefSJacques Pienaar #include "llvm/ADT/ArrayRef.h"
21436c6c9cSStella Laurenzo #include "llvm/ADT/SmallVector.h"
22436c6c9cSStella Laurenzo 
231fc096afSMehdi Amini #include <utility>
241fc096afSMehdi Amini 
25436c6c9cSStella Laurenzo namespace py = pybind11;
26436c6c9cSStella Laurenzo using namespace mlir;
27436c6c9cSStella Laurenzo using namespace mlir::python;
28436c6c9cSStella Laurenzo 
29436c6c9cSStella Laurenzo using llvm::SmallVector;
30436c6c9cSStella Laurenzo using llvm::StringRef;
31436c6c9cSStella Laurenzo using llvm::Twine;
32436c6c9cSStella Laurenzo 
33436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
34436c6c9cSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
35436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
36436c6c9cSStella Laurenzo 
// Docstrings for PyMlirContext parsing/location factories and Module.parse.
// Written as adjacent string literals; the concatenated bytes are the exact
// text exposed to Python.
static const char kContextParseTypeDocstring[] =
    "Parses the assembly form of a type.\n"
    "\n"
    "Returns a Type object or raises a ValueError if the type cannot be parsed.\n"
    "\n"
    "See also: https://mlir.llvm.org/docs/LangRef/#type-system\n";

static const char kContextGetCallSiteLocationDocstring[] =
    "Gets a Location representing a caller and callsite";

static const char kContextGetFileLocationDocstring[] =
    "Gets a Location representing a file, line and column";

static const char kContextGetFusedLocationDocstring[] =
    "Gets a Location representing a fused location with optional metadata";

static const char kContextGetNameLocationDocString[] =
    "Gets a Location representing a named location with optional child "
    "location";

static const char kModuleParseDocstring[] =
    "Parses a module's assembly format from a string.\n"
    "\n"
    "Returns a new MlirModule or raises a ValueError if the parsing fails.\n"
    "\n"
    "See also: https://mlir.llvm.org/docs/LangRef/\n";
64436c6c9cSStella Laurenzo 
// Docstring for Operation.create; each literal below is one line of the
// Python-visible text.
static const char kOperationCreateDocstring[] =
    "Creates a new operation.\n"
    "\n"
    "Args:\n"
    "  name: Operation name (e.g. \"dialect.operation\").\n"
    "  results: Sequence of Type representing op result types.\n"
    "  attributes: Dict of str:Attribute.\n"
    "  successors: List of Block for the operation's successors.\n"
    "  regions: Number of regions to create.\n"
    "  location: A Location object (defaults to resolve from context manager).\n"
    "  ip: An InsertionPoint (defaults to resolve from context manager or set to\n"
    "    False to disable insertion, even with an insertion point set in the\n"
    "    context manager).\n"
    "Returns:\n"
    "  A new \"detached\" Operation object. Detached operations can be added\n"
    "  to blocks, which causes them to become \"attached.\"\n";
82436c6c9cSStella Laurenzo 
// Docstring for Operation.print. Keyword names listed here must match the
// snake_case keyword arguments exposed to Python (the previous text misspelled
// `use_local_scope` as `use_local_Scope`).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";
108436c6c9cSStella Laurenzo 
// Docstrings for Operation.get_asm / __str__, the dump() helpers,
// BlockList.append and Value.__str__. Adjacent literals concatenate into the
// exact Python-visible text.
static const char kOperationGetAsmDocstring[] =
    "Gets the assembly form of the operation with all options available.\n"
    "\n"
    "Args:\n"
    "  binary: Whether to return a bytes (True) or str (False) object. Defaults to\n"
    "    False.\n"
    "  ... others ...: See the print() method for common keyword arguments for\n"
    "    configuring the printout.\n"
    "Returns:\n"
    "  Either a bytes or str object, depending on the setting of the 'binary'\n"
    "  argument.\n";

static const char kOperationStrDunderDocstring[] =
    "Gets the assembly form of the operation with default options.\n"
    "\n"
    "If more advanced control over the assembly formatting or I/O options is needed,\n"
    "use the dedicated print or get_asm method, which supports keyword arguments to\n"
    "customize behavior.\n";

static const char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";

static const char kAppendBlockDocstring[] =
    "Appends a new block, with argument types as positional args.\n"
    "\n"
    "Returns:\n"
    "  The created block.\n";

static const char kValueDunderStrDocstring[] =
    "Returns the string form of the value.\n"
    "\n"
    "If the value is a block argument, this is the assembly form of its type and the\n"
    "position in the argument list. If the value is an operation result, this is\n"
    "equivalent to printing the operation that produced it.\n";
147436c6c9cSStella Laurenzo 
148436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
149436c6c9cSStella Laurenzo // Utilities.
150436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
151436c6c9cSStella Laurenzo 
1524acd8457SAlex Zinenko /// Helper for creating an @classmethod.
153436c6c9cSStella Laurenzo template <class Func, typename... Args>
classmethod(Func f,Args...args)154436c6c9cSStella Laurenzo py::object classmethod(Func f, Args... args) {
155436c6c9cSStella Laurenzo   py::object cf = py::cpp_function(f, args...);
156436c6c9cSStella Laurenzo   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
157436c6c9cSStella Laurenzo }
158436c6c9cSStella Laurenzo 
159436c6c9cSStella Laurenzo static py::object
createCustomDialectWrapper(const std::string & dialectNamespace,py::object dialectDescriptor)160436c6c9cSStella Laurenzo createCustomDialectWrapper(const std::string &dialectNamespace,
161436c6c9cSStella Laurenzo                            py::object dialectDescriptor) {
162436c6c9cSStella Laurenzo   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
163436c6c9cSStella Laurenzo   if (!dialectClass) {
164436c6c9cSStella Laurenzo     // Use the base class.
165436c6c9cSStella Laurenzo     return py::cast(PyDialect(std::move(dialectDescriptor)));
166436c6c9cSStella Laurenzo   }
167436c6c9cSStella Laurenzo 
168436c6c9cSStella Laurenzo   // Create the custom implementation.
169436c6c9cSStella Laurenzo   return (*dialectClass)(std::move(dialectDescriptor));
170436c6c9cSStella Laurenzo }
171436c6c9cSStella Laurenzo 
toMlirStringRef(const std::string & s)172436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
173436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
174436c6c9cSStella Laurenzo }
175436c6c9cSStella Laurenzo 
1764acd8457SAlex Zinenko /// Wrapper for the global LLVM debugging flag.
1774acd8457SAlex Zinenko struct PyGlobalDebugFlag {
setPyGlobalDebugFlag1784acd8457SAlex Zinenko   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
1794acd8457SAlex Zinenko 
getPyGlobalDebugFlag1801fc096afSMehdi Amini   static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
1814acd8457SAlex Zinenko 
bindPyGlobalDebugFlag1824acd8457SAlex Zinenko   static void bind(py::module &m) {
1834acd8457SAlex Zinenko     // Debug flags.
184f05ff4f7SStella Laurenzo     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
1854acd8457SAlex Zinenko         .def_property_static("flag", &PyGlobalDebugFlag::get,
1864acd8457SAlex Zinenko                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
1874acd8457SAlex Zinenko   }
1884acd8457SAlex Zinenko };
1894acd8457SAlex Zinenko 
190436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
191436c6c9cSStella Laurenzo // Collections.
192436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
193436c6c9cSStella Laurenzo 
194436c6c9cSStella Laurenzo namespace {
195436c6c9cSStella Laurenzo 
196436c6c9cSStella Laurenzo class PyRegionIterator {
197436c6c9cSStella Laurenzo public:
PyRegionIterator(PyOperationRef operation)198436c6c9cSStella Laurenzo   PyRegionIterator(PyOperationRef operation)
199436c6c9cSStella Laurenzo       : operation(std::move(operation)) {}
200436c6c9cSStella Laurenzo 
dunderIter()201436c6c9cSStella Laurenzo   PyRegionIterator &dunderIter() { return *this; }
202436c6c9cSStella Laurenzo 
dunderNext()203436c6c9cSStella Laurenzo   PyRegion dunderNext() {
204436c6c9cSStella Laurenzo     operation->checkValid();
205436c6c9cSStella Laurenzo     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
206436c6c9cSStella Laurenzo       throw py::stop_iteration();
207436c6c9cSStella Laurenzo     }
208436c6c9cSStella Laurenzo     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
209436c6c9cSStella Laurenzo     return PyRegion(operation, region);
210436c6c9cSStella Laurenzo   }
211436c6c9cSStella Laurenzo 
bind(py::module & m)212436c6c9cSStella Laurenzo   static void bind(py::module &m) {
213f05ff4f7SStella Laurenzo     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
214436c6c9cSStella Laurenzo         .def("__iter__", &PyRegionIterator::dunderIter)
215436c6c9cSStella Laurenzo         .def("__next__", &PyRegionIterator::dunderNext);
216436c6c9cSStella Laurenzo   }
217436c6c9cSStella Laurenzo 
218436c6c9cSStella Laurenzo private:
219436c6c9cSStella Laurenzo   PyOperationRef operation;
220436c6c9cSStella Laurenzo   int nextIndex = 0;
221436c6c9cSStella Laurenzo };
222436c6c9cSStella Laurenzo 
223436c6c9cSStella Laurenzo /// Regions of an op are fixed length and indexed numerically so are represented
224436c6c9cSStella Laurenzo /// with a sequence-like container.
225436c6c9cSStella Laurenzo class PyRegionList {
226436c6c9cSStella Laurenzo public:
PyRegionList(PyOperationRef operation)227436c6c9cSStella Laurenzo   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
228436c6c9cSStella Laurenzo 
dunderLen()229436c6c9cSStella Laurenzo   intptr_t dunderLen() {
230436c6c9cSStella Laurenzo     operation->checkValid();
231436c6c9cSStella Laurenzo     return mlirOperationGetNumRegions(operation->get());
232436c6c9cSStella Laurenzo   }
233436c6c9cSStella Laurenzo 
dunderGetItem(intptr_t index)234436c6c9cSStella Laurenzo   PyRegion dunderGetItem(intptr_t index) {
235436c6c9cSStella Laurenzo     // dunderLen checks validity.
236436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
237436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
238436c6c9cSStella Laurenzo                        "attempt to access out of bounds region");
239436c6c9cSStella Laurenzo     }
240436c6c9cSStella Laurenzo     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
241436c6c9cSStella Laurenzo     return PyRegion(operation, region);
242436c6c9cSStella Laurenzo   }
243436c6c9cSStella Laurenzo 
bind(py::module & m)244436c6c9cSStella Laurenzo   static void bind(py::module &m) {
245f05ff4f7SStella Laurenzo     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
246436c6c9cSStella Laurenzo         .def("__len__", &PyRegionList::dunderLen)
247436c6c9cSStella Laurenzo         .def("__getitem__", &PyRegionList::dunderGetItem);
248436c6c9cSStella Laurenzo   }
249436c6c9cSStella Laurenzo 
250436c6c9cSStella Laurenzo private:
251436c6c9cSStella Laurenzo   PyOperationRef operation;
252436c6c9cSStella Laurenzo };
253436c6c9cSStella Laurenzo 
254436c6c9cSStella Laurenzo class PyBlockIterator {
255436c6c9cSStella Laurenzo public:
PyBlockIterator(PyOperationRef operation,MlirBlock next)256436c6c9cSStella Laurenzo   PyBlockIterator(PyOperationRef operation, MlirBlock next)
257436c6c9cSStella Laurenzo       : operation(std::move(operation)), next(next) {}
258436c6c9cSStella Laurenzo 
dunderIter()259436c6c9cSStella Laurenzo   PyBlockIterator &dunderIter() { return *this; }
260436c6c9cSStella Laurenzo 
dunderNext()261436c6c9cSStella Laurenzo   PyBlock dunderNext() {
262436c6c9cSStella Laurenzo     operation->checkValid();
263436c6c9cSStella Laurenzo     if (mlirBlockIsNull(next)) {
264436c6c9cSStella Laurenzo       throw py::stop_iteration();
265436c6c9cSStella Laurenzo     }
266436c6c9cSStella Laurenzo 
267436c6c9cSStella Laurenzo     PyBlock returnBlock(operation, next);
268436c6c9cSStella Laurenzo     next = mlirBlockGetNextInRegion(next);
269436c6c9cSStella Laurenzo     return returnBlock;
270436c6c9cSStella Laurenzo   }
271436c6c9cSStella Laurenzo 
bind(py::module & m)272436c6c9cSStella Laurenzo   static void bind(py::module &m) {
273f05ff4f7SStella Laurenzo     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
274436c6c9cSStella Laurenzo         .def("__iter__", &PyBlockIterator::dunderIter)
275436c6c9cSStella Laurenzo         .def("__next__", &PyBlockIterator::dunderNext);
276436c6c9cSStella Laurenzo   }
277436c6c9cSStella Laurenzo 
278436c6c9cSStella Laurenzo private:
279436c6c9cSStella Laurenzo   PyOperationRef operation;
280436c6c9cSStella Laurenzo   MlirBlock next;
281436c6c9cSStella Laurenzo };
282436c6c9cSStella Laurenzo 
283436c6c9cSStella Laurenzo /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
284436c6c9cSStella Laurenzo /// we present them as a more full-featured list-like container but optimize
285436c6c9cSStella Laurenzo /// it for forward iteration. Blocks are always owned by a region.
286436c6c9cSStella Laurenzo class PyBlockList {
287436c6c9cSStella Laurenzo public:
PyBlockList(PyOperationRef operation,MlirRegion region)288436c6c9cSStella Laurenzo   PyBlockList(PyOperationRef operation, MlirRegion region)
289436c6c9cSStella Laurenzo       : operation(std::move(operation)), region(region) {}
290436c6c9cSStella Laurenzo 
dunderIter()291436c6c9cSStella Laurenzo   PyBlockIterator dunderIter() {
292436c6c9cSStella Laurenzo     operation->checkValid();
293436c6c9cSStella Laurenzo     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
294436c6c9cSStella Laurenzo   }
295436c6c9cSStella Laurenzo 
dunderLen()296436c6c9cSStella Laurenzo   intptr_t dunderLen() {
297436c6c9cSStella Laurenzo     operation->checkValid();
298436c6c9cSStella Laurenzo     intptr_t count = 0;
299436c6c9cSStella Laurenzo     MlirBlock block = mlirRegionGetFirstBlock(region);
300436c6c9cSStella Laurenzo     while (!mlirBlockIsNull(block)) {
301436c6c9cSStella Laurenzo       count += 1;
302436c6c9cSStella Laurenzo       block = mlirBlockGetNextInRegion(block);
303436c6c9cSStella Laurenzo     }
304436c6c9cSStella Laurenzo     return count;
305436c6c9cSStella Laurenzo   }
306436c6c9cSStella Laurenzo 
dunderGetItem(intptr_t index)307436c6c9cSStella Laurenzo   PyBlock dunderGetItem(intptr_t index) {
308436c6c9cSStella Laurenzo     operation->checkValid();
309436c6c9cSStella Laurenzo     if (index < 0) {
310436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
311436c6c9cSStella Laurenzo                        "attempt to access out of bounds block");
312436c6c9cSStella Laurenzo     }
313436c6c9cSStella Laurenzo     MlirBlock block = mlirRegionGetFirstBlock(region);
314436c6c9cSStella Laurenzo     while (!mlirBlockIsNull(block)) {
315436c6c9cSStella Laurenzo       if (index == 0) {
316436c6c9cSStella Laurenzo         return PyBlock(operation, block);
317436c6c9cSStella Laurenzo       }
318436c6c9cSStella Laurenzo       block = mlirBlockGetNextInRegion(block);
319436c6c9cSStella Laurenzo       index -= 1;
320436c6c9cSStella Laurenzo     }
321436c6c9cSStella Laurenzo     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
322436c6c9cSStella Laurenzo   }
323436c6c9cSStella Laurenzo 
appendBlock(const py::args & pyArgTypes)3241fc096afSMehdi Amini   PyBlock appendBlock(const py::args &pyArgTypes) {
325436c6c9cSStella Laurenzo     operation->checkValid();
326436c6c9cSStella Laurenzo     llvm::SmallVector<MlirType, 4> argTypes;
327e084679fSRiver Riddle     llvm::SmallVector<MlirLocation, 4> argLocs;
328436c6c9cSStella Laurenzo     argTypes.reserve(pyArgTypes.size());
329e084679fSRiver Riddle     argLocs.reserve(pyArgTypes.size());
330436c6c9cSStella Laurenzo     for (auto &pyArg : pyArgTypes) {
331436c6c9cSStella Laurenzo       argTypes.push_back(pyArg.cast<PyType &>());
332e084679fSRiver Riddle       // TODO: Pass in a proper location here.
333e084679fSRiver Riddle       argLocs.push_back(
334e084679fSRiver Riddle           mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
335436c6c9cSStella Laurenzo     }
336436c6c9cSStella Laurenzo 
337e084679fSRiver Riddle     MlirBlock block =
338e084679fSRiver Riddle         mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
339436c6c9cSStella Laurenzo     mlirRegionAppendOwnedBlock(region, block);
340436c6c9cSStella Laurenzo     return PyBlock(operation, block);
341436c6c9cSStella Laurenzo   }
342436c6c9cSStella Laurenzo 
bind(py::module & m)343436c6c9cSStella Laurenzo   static void bind(py::module &m) {
344f05ff4f7SStella Laurenzo     py::class_<PyBlockList>(m, "BlockList", py::module_local())
345436c6c9cSStella Laurenzo         .def("__getitem__", &PyBlockList::dunderGetItem)
346436c6c9cSStella Laurenzo         .def("__iter__", &PyBlockList::dunderIter)
347436c6c9cSStella Laurenzo         .def("__len__", &PyBlockList::dunderLen)
348436c6c9cSStella Laurenzo         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
349436c6c9cSStella Laurenzo   }
350436c6c9cSStella Laurenzo 
351436c6c9cSStella Laurenzo private:
352436c6c9cSStella Laurenzo   PyOperationRef operation;
353436c6c9cSStella Laurenzo   MlirRegion region;
354436c6c9cSStella Laurenzo };
355436c6c9cSStella Laurenzo 
356436c6c9cSStella Laurenzo class PyOperationIterator {
357436c6c9cSStella Laurenzo public:
PyOperationIterator(PyOperationRef parentOperation,MlirOperation next)358436c6c9cSStella Laurenzo   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
359436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), next(next) {}
360436c6c9cSStella Laurenzo 
dunderIter()361436c6c9cSStella Laurenzo   PyOperationIterator &dunderIter() { return *this; }
362436c6c9cSStella Laurenzo 
dunderNext()363436c6c9cSStella Laurenzo   py::object dunderNext() {
364436c6c9cSStella Laurenzo     parentOperation->checkValid();
365436c6c9cSStella Laurenzo     if (mlirOperationIsNull(next)) {
366436c6c9cSStella Laurenzo       throw py::stop_iteration();
367436c6c9cSStella Laurenzo     }
368436c6c9cSStella Laurenzo 
369436c6c9cSStella Laurenzo     PyOperationRef returnOperation =
370436c6c9cSStella Laurenzo         PyOperation::forOperation(parentOperation->getContext(), next);
371436c6c9cSStella Laurenzo     next = mlirOperationGetNextInBlock(next);
372436c6c9cSStella Laurenzo     return returnOperation->createOpView();
373436c6c9cSStella Laurenzo   }
374436c6c9cSStella Laurenzo 
bind(py::module & m)375436c6c9cSStella Laurenzo   static void bind(py::module &m) {
376f05ff4f7SStella Laurenzo     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
377436c6c9cSStella Laurenzo         .def("__iter__", &PyOperationIterator::dunderIter)
378436c6c9cSStella Laurenzo         .def("__next__", &PyOperationIterator::dunderNext);
379436c6c9cSStella Laurenzo   }
380436c6c9cSStella Laurenzo 
381436c6c9cSStella Laurenzo private:
382436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
383436c6c9cSStella Laurenzo   MlirOperation next;
384436c6c9cSStella Laurenzo };
385436c6c9cSStella Laurenzo 
386436c6c9cSStella Laurenzo /// Operations are exposed by the C-API as a forward-only linked list. In
387436c6c9cSStella Laurenzo /// Python, we present them as a more full-featured list-like container but
388436c6c9cSStella Laurenzo /// optimize it for forward iteration. Iterable operations are always owned
389436c6c9cSStella Laurenzo /// by a block.
390436c6c9cSStella Laurenzo class PyOperationList {
391436c6c9cSStella Laurenzo public:
PyOperationList(PyOperationRef parentOperation,MlirBlock block)392436c6c9cSStella Laurenzo   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
393436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), block(block) {}
394436c6c9cSStella Laurenzo 
dunderIter()395436c6c9cSStella Laurenzo   PyOperationIterator dunderIter() {
396436c6c9cSStella Laurenzo     parentOperation->checkValid();
397436c6c9cSStella Laurenzo     return PyOperationIterator(parentOperation,
398436c6c9cSStella Laurenzo                                mlirBlockGetFirstOperation(block));
399436c6c9cSStella Laurenzo   }
400436c6c9cSStella Laurenzo 
dunderLen()401436c6c9cSStella Laurenzo   intptr_t dunderLen() {
402436c6c9cSStella Laurenzo     parentOperation->checkValid();
403436c6c9cSStella Laurenzo     intptr_t count = 0;
404436c6c9cSStella Laurenzo     MlirOperation childOp = mlirBlockGetFirstOperation(block);
405436c6c9cSStella Laurenzo     while (!mlirOperationIsNull(childOp)) {
406436c6c9cSStella Laurenzo       count += 1;
407436c6c9cSStella Laurenzo       childOp = mlirOperationGetNextInBlock(childOp);
408436c6c9cSStella Laurenzo     }
409436c6c9cSStella Laurenzo     return count;
410436c6c9cSStella Laurenzo   }
411436c6c9cSStella Laurenzo 
dunderGetItem(intptr_t index)412436c6c9cSStella Laurenzo   py::object dunderGetItem(intptr_t index) {
413436c6c9cSStella Laurenzo     parentOperation->checkValid();
414436c6c9cSStella Laurenzo     if (index < 0) {
415436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
416436c6c9cSStella Laurenzo                        "attempt to access out of bounds operation");
417436c6c9cSStella Laurenzo     }
418436c6c9cSStella Laurenzo     MlirOperation childOp = mlirBlockGetFirstOperation(block);
419436c6c9cSStella Laurenzo     while (!mlirOperationIsNull(childOp)) {
420436c6c9cSStella Laurenzo       if (index == 0) {
421436c6c9cSStella Laurenzo         return PyOperation::forOperation(parentOperation->getContext(), childOp)
422436c6c9cSStella Laurenzo             ->createOpView();
423436c6c9cSStella Laurenzo       }
424436c6c9cSStella Laurenzo       childOp = mlirOperationGetNextInBlock(childOp);
425436c6c9cSStella Laurenzo       index -= 1;
426436c6c9cSStella Laurenzo     }
427436c6c9cSStella Laurenzo     throw SetPyError(PyExc_IndexError,
428436c6c9cSStella Laurenzo                      "attempt to access out of bounds operation");
429436c6c9cSStella Laurenzo   }
430436c6c9cSStella Laurenzo 
bind(py::module & m)431436c6c9cSStella Laurenzo   static void bind(py::module &m) {
432f05ff4f7SStella Laurenzo     py::class_<PyOperationList>(m, "OperationList", py::module_local())
433436c6c9cSStella Laurenzo         .def("__getitem__", &PyOperationList::dunderGetItem)
434436c6c9cSStella Laurenzo         .def("__iter__", &PyOperationList::dunderIter)
435436c6c9cSStella Laurenzo         .def("__len__", &PyOperationList::dunderLen);
436436c6c9cSStella Laurenzo   }
437436c6c9cSStella Laurenzo 
438436c6c9cSStella Laurenzo private:
439436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
440436c6c9cSStella Laurenzo   MlirBlock block;
441436c6c9cSStella Laurenzo };
442436c6c9cSStella Laurenzo 
443436c6c9cSStella Laurenzo } // namespace
444436c6c9cSStella Laurenzo 
445436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
446436c6c9cSStella Laurenzo // PyMlirContext
447436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
448436c6c9cSStella Laurenzo 
PyMlirContext(MlirContext context)449436c6c9cSStella Laurenzo PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
450436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
451436c6c9cSStella Laurenzo   auto &liveContexts = getLiveContexts();
452436c6c9cSStella Laurenzo   liveContexts[context.ptr] = this;
453436c6c9cSStella Laurenzo }
454436c6c9cSStella Laurenzo 
~PyMlirContext()455436c6c9cSStella Laurenzo PyMlirContext::~PyMlirContext() {
456436c6c9cSStella Laurenzo   // Note that the only public way to construct an instance is via the
457436c6c9cSStella Laurenzo   // forContext method, which always puts the associated handle into
458436c6c9cSStella Laurenzo   // liveContexts.
459436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
460436c6c9cSStella Laurenzo   getLiveContexts().erase(context.ptr);
461436c6c9cSStella Laurenzo   mlirContextDestroy(context);
462436c6c9cSStella Laurenzo }
463436c6c9cSStella Laurenzo 
getCapsule()464436c6c9cSStella Laurenzo py::object PyMlirContext::getCapsule() {
465436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
466436c6c9cSStella Laurenzo }
467436c6c9cSStella Laurenzo 
createFromCapsule(py::object capsule)468436c6c9cSStella Laurenzo py::object PyMlirContext::createFromCapsule(py::object capsule) {
469436c6c9cSStella Laurenzo   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
470436c6c9cSStella Laurenzo   if (mlirContextIsNull(rawContext))
471436c6c9cSStella Laurenzo     throw py::error_already_set();
472436c6c9cSStella Laurenzo   return forContext(rawContext).releaseObject();
473436c6c9cSStella Laurenzo }
474436c6c9cSStella Laurenzo 
createNewContextForInit()475436c6c9cSStella Laurenzo PyMlirContext *PyMlirContext::createNewContextForInit() {
476436c6c9cSStella Laurenzo   MlirContext context = mlirContextCreate();
477436c6c9cSStella Laurenzo   return new PyMlirContext(context);
478436c6c9cSStella Laurenzo }
479436c6c9cSStella Laurenzo 
forContext(MlirContext context)480436c6c9cSStella Laurenzo PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
481436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
482436c6c9cSStella Laurenzo   auto &liveContexts = getLiveContexts();
483436c6c9cSStella Laurenzo   auto it = liveContexts.find(context.ptr);
484436c6c9cSStella Laurenzo   if (it == liveContexts.end()) {
485436c6c9cSStella Laurenzo     // Create.
486436c6c9cSStella Laurenzo     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
487436c6c9cSStella Laurenzo     py::object pyRef = py::cast(unownedContextWrapper);
488436c6c9cSStella Laurenzo     assert(pyRef && "cast to py::object failed");
489436c6c9cSStella Laurenzo     liveContexts[context.ptr] = unownedContextWrapper;
490436c6c9cSStella Laurenzo     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
491436c6c9cSStella Laurenzo   }
492436c6c9cSStella Laurenzo   // Use existing.
493436c6c9cSStella Laurenzo   py::object pyRef = py::cast(it->second);
494436c6c9cSStella Laurenzo   return PyMlirContextRef(it->second, std::move(pyRef));
495436c6c9cSStella Laurenzo }
496436c6c9cSStella Laurenzo 
getLiveContexts()497436c6c9cSStella Laurenzo PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
498436c6c9cSStella Laurenzo   static LiveContextMap liveContexts;
499436c6c9cSStella Laurenzo   return liveContexts;
500436c6c9cSStella Laurenzo }
501436c6c9cSStella Laurenzo 
getLiveCount()502436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
503436c6c9cSStella Laurenzo 
getLiveOperationCount()504436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
505436c6c9cSStella Laurenzo 
clearLiveOperations()5066b0bed7eSJohn Demme size_t PyMlirContext::clearLiveOperations() {
5076b0bed7eSJohn Demme   for (auto &op : liveOperations)
5086b0bed7eSJohn Demme     op.second.second->setInvalid();
5096b0bed7eSJohn Demme   size_t numInvalidated = liveOperations.size();
5106b0bed7eSJohn Demme   liveOperations.clear();
5116b0bed7eSJohn Demme   return numInvalidated;
5126b0bed7eSJohn Demme }
5136b0bed7eSJohn Demme 
getLiveModuleCount()514436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
515436c6c9cSStella Laurenzo 
contextEnter()516436c6c9cSStella Laurenzo pybind11::object PyMlirContext::contextEnter() {
517436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushContext(*this);
518436c6c9cSStella Laurenzo }
519436c6c9cSStella Laurenzo 
contextExit(const pybind11::object & excType,const pybind11::object & excVal,const pybind11::object & excTb)5201fc096afSMehdi Amini void PyMlirContext::contextExit(const pybind11::object &excType,
5211fc096afSMehdi Amini                                 const pybind11::object &excVal,
5221fc096afSMehdi Amini                                 const pybind11::object &excTb) {
523436c6c9cSStella Laurenzo   PyThreadContextEntry::popContext(*this);
524436c6c9cSStella Laurenzo }
525436c6c9cSStella Laurenzo 
attachDiagnosticHandler(py::object callback)5267ee25bc5SStella Laurenzo py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
5277ee25bc5SStella Laurenzo   // Note that ownership is transferred to the delete callback below by way of
5287ee25bc5SStella Laurenzo   // an explicit inc_ref (borrow).
5297ee25bc5SStella Laurenzo   PyDiagnosticHandler *pyHandler =
5307ee25bc5SStella Laurenzo       new PyDiagnosticHandler(get(), std::move(callback));
5317ee25bc5SStella Laurenzo   py::object pyHandlerObject =
5327ee25bc5SStella Laurenzo       py::cast(pyHandler, py::return_value_policy::take_ownership);
5337ee25bc5SStella Laurenzo   pyHandlerObject.inc_ref();
5347ee25bc5SStella Laurenzo 
5357ee25bc5SStella Laurenzo   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
5367ee25bc5SStella Laurenzo   // guaranteed to be known to pybind.
5377ee25bc5SStella Laurenzo   auto handlerCallback =
5387ee25bc5SStella Laurenzo       +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
5397ee25bc5SStella Laurenzo     PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
5407ee25bc5SStella Laurenzo     py::object pyDiagnosticObject =
5417ee25bc5SStella Laurenzo         py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
5427ee25bc5SStella Laurenzo 
5437ee25bc5SStella Laurenzo     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
5447ee25bc5SStella Laurenzo     bool result = false;
5457ee25bc5SStella Laurenzo     {
5467ee25bc5SStella Laurenzo       // Since this can be called from arbitrary C++ contexts, always get the
5477ee25bc5SStella Laurenzo       // gil.
5487ee25bc5SStella Laurenzo       py::gil_scoped_acquire gil;
5497ee25bc5SStella Laurenzo       try {
5507ee25bc5SStella Laurenzo         result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
5517ee25bc5SStella Laurenzo       } catch (std::exception &e) {
5527ee25bc5SStella Laurenzo         fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
5537ee25bc5SStella Laurenzo                 e.what());
5547ee25bc5SStella Laurenzo         pyHandler->hadError = true;
5557ee25bc5SStella Laurenzo       }
5567ee25bc5SStella Laurenzo     }
5577ee25bc5SStella Laurenzo 
5587ee25bc5SStella Laurenzo     pyDiagnostic->invalidate();
5597ee25bc5SStella Laurenzo     return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
5607ee25bc5SStella Laurenzo   };
5617ee25bc5SStella Laurenzo   auto deleteCallback = +[](void *userData) {
5627ee25bc5SStella Laurenzo     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
5637ee25bc5SStella Laurenzo     assert(pyHandler->registeredID && "handler is not registered");
5647ee25bc5SStella Laurenzo     pyHandler->registeredID.reset();
5657ee25bc5SStella Laurenzo 
5667ee25bc5SStella Laurenzo     // Decrement reference, balancing the inc_ref() above.
5677ee25bc5SStella Laurenzo     py::object pyHandlerObject =
5687ee25bc5SStella Laurenzo         py::cast(pyHandler, py::return_value_policy::reference);
5697ee25bc5SStella Laurenzo     pyHandlerObject.dec_ref();
5707ee25bc5SStella Laurenzo   };
5717ee25bc5SStella Laurenzo 
5727ee25bc5SStella Laurenzo   pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
5737ee25bc5SStella Laurenzo       get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
5747ee25bc5SStella Laurenzo   return pyHandlerObject;
5757ee25bc5SStella Laurenzo }
5767ee25bc5SStella Laurenzo 
resolve()577436c6c9cSStella Laurenzo PyMlirContext &DefaultingPyMlirContext::resolve() {
578436c6c9cSStella Laurenzo   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
579436c6c9cSStella Laurenzo   if (!context) {
580436c6c9cSStella Laurenzo     throw SetPyError(
581436c6c9cSStella Laurenzo         PyExc_RuntimeError,
582436c6c9cSStella Laurenzo         "An MLIR function requires a Context but none was provided in the call "
583436c6c9cSStella Laurenzo         "or from the surrounding environment. Either pass to the function with "
584436c6c9cSStella Laurenzo         "a 'context=' argument or establish a default using 'with Context():'");
585436c6c9cSStella Laurenzo   }
586436c6c9cSStella Laurenzo   return *context;
587436c6c9cSStella Laurenzo }
588436c6c9cSStella Laurenzo 
589436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
590436c6c9cSStella Laurenzo // PyThreadContextEntry management
591436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
592436c6c9cSStella Laurenzo 
getStack()593436c6c9cSStella Laurenzo std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
594436c6c9cSStella Laurenzo   static thread_local std::vector<PyThreadContextEntry> stack;
595436c6c9cSStella Laurenzo   return stack;
596436c6c9cSStella Laurenzo }
597436c6c9cSStella Laurenzo 
getTopOfStack()598436c6c9cSStella Laurenzo PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
599436c6c9cSStella Laurenzo   auto &stack = getStack();
600436c6c9cSStella Laurenzo   if (stack.empty())
601436c6c9cSStella Laurenzo     return nullptr;
602436c6c9cSStella Laurenzo   return &stack.back();
603436c6c9cSStella Laurenzo }
604436c6c9cSStella Laurenzo 
push(FrameKind frameKind,py::object context,py::object insertionPoint,py::object location)605436c6c9cSStella Laurenzo void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
606436c6c9cSStella Laurenzo                                 py::object insertionPoint,
607436c6c9cSStella Laurenzo                                 py::object location) {
608436c6c9cSStella Laurenzo   auto &stack = getStack();
609436c6c9cSStella Laurenzo   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
610436c6c9cSStella Laurenzo                      std::move(location));
611436c6c9cSStella Laurenzo   // If the new stack has more than one entry and the context of the new top
612436c6c9cSStella Laurenzo   // entry matches the previous, copy the insertionPoint and location from the
613436c6c9cSStella Laurenzo   // previous entry if missing from the new top entry.
614436c6c9cSStella Laurenzo   if (stack.size() > 1) {
615436c6c9cSStella Laurenzo     auto &prev = *(stack.rbegin() + 1);
616436c6c9cSStella Laurenzo     auto &current = stack.back();
617436c6c9cSStella Laurenzo     if (current.context.is(prev.context)) {
618436c6c9cSStella Laurenzo       // Default non-context objects from the previous entry.
619436c6c9cSStella Laurenzo       if (!current.insertionPoint)
620436c6c9cSStella Laurenzo         current.insertionPoint = prev.insertionPoint;
621436c6c9cSStella Laurenzo       if (!current.location)
622436c6c9cSStella Laurenzo         current.location = prev.location;
623436c6c9cSStella Laurenzo     }
624436c6c9cSStella Laurenzo   }
625436c6c9cSStella Laurenzo }
626436c6c9cSStella Laurenzo 
getContext()627436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getContext() {
628436c6c9cSStella Laurenzo   if (!context)
629436c6c9cSStella Laurenzo     return nullptr;
630436c6c9cSStella Laurenzo   return py::cast<PyMlirContext *>(context);
631436c6c9cSStella Laurenzo }
632436c6c9cSStella Laurenzo 
getInsertionPoint()633436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
634436c6c9cSStella Laurenzo   if (!insertionPoint)
635436c6c9cSStella Laurenzo     return nullptr;
636436c6c9cSStella Laurenzo   return py::cast<PyInsertionPoint *>(insertionPoint);
637436c6c9cSStella Laurenzo }
638436c6c9cSStella Laurenzo 
getLocation()639436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getLocation() {
640436c6c9cSStella Laurenzo   if (!location)
641436c6c9cSStella Laurenzo     return nullptr;
642436c6c9cSStella Laurenzo   return py::cast<PyLocation *>(location);
643436c6c9cSStella Laurenzo }
644436c6c9cSStella Laurenzo 
getDefaultContext()645436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getDefaultContext() {
646436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
647436c6c9cSStella Laurenzo   return tos ? tos->getContext() : nullptr;
648436c6c9cSStella Laurenzo }
649436c6c9cSStella Laurenzo 
getDefaultInsertionPoint()650436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
651436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
652436c6c9cSStella Laurenzo   return tos ? tos->getInsertionPoint() : nullptr;
653436c6c9cSStella Laurenzo }
654436c6c9cSStella Laurenzo 
getDefaultLocation()655436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getDefaultLocation() {
656436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
657436c6c9cSStella Laurenzo   return tos ? tos->getLocation() : nullptr;
658436c6c9cSStella Laurenzo }
659436c6c9cSStella Laurenzo 
pushContext(PyMlirContext & context)660436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
661436c6c9cSStella Laurenzo   py::object contextObj = py::cast(context);
662436c6c9cSStella Laurenzo   push(FrameKind::Context, /*context=*/contextObj,
663436c6c9cSStella Laurenzo        /*insertionPoint=*/py::object(),
664436c6c9cSStella Laurenzo        /*location=*/py::object());
665436c6c9cSStella Laurenzo   return contextObj;
666436c6c9cSStella Laurenzo }
667436c6c9cSStella Laurenzo 
popContext(PyMlirContext & context)668436c6c9cSStella Laurenzo void PyThreadContextEntry::popContext(PyMlirContext &context) {
669436c6c9cSStella Laurenzo   auto &stack = getStack();
670436c6c9cSStella Laurenzo   if (stack.empty())
671436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
672436c6c9cSStella Laurenzo   auto &tos = stack.back();
673436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
674436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
675436c6c9cSStella Laurenzo   stack.pop_back();
676436c6c9cSStella Laurenzo }
677436c6c9cSStella Laurenzo 
678436c6c9cSStella Laurenzo py::object
pushInsertionPoint(PyInsertionPoint & insertionPoint)679436c6c9cSStella Laurenzo PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
680436c6c9cSStella Laurenzo   py::object contextObj =
681436c6c9cSStella Laurenzo       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
682436c6c9cSStella Laurenzo   py::object insertionPointObj = py::cast(insertionPoint);
683436c6c9cSStella Laurenzo   push(FrameKind::InsertionPoint,
684436c6c9cSStella Laurenzo        /*context=*/contextObj,
685436c6c9cSStella Laurenzo        /*insertionPoint=*/insertionPointObj,
686436c6c9cSStella Laurenzo        /*location=*/py::object());
687436c6c9cSStella Laurenzo   return insertionPointObj;
688436c6c9cSStella Laurenzo }
689436c6c9cSStella Laurenzo 
popInsertionPoint(PyInsertionPoint & insertionPoint)690436c6c9cSStella Laurenzo void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
691436c6c9cSStella Laurenzo   auto &stack = getStack();
692436c6c9cSStella Laurenzo   if (stack.empty())
693436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError,
694436c6c9cSStella Laurenzo                      "Unbalanced InsertionPoint enter/exit");
695436c6c9cSStella Laurenzo   auto &tos = stack.back();
696436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::InsertionPoint &&
697436c6c9cSStella Laurenzo       tos.getInsertionPoint() != &insertionPoint)
698436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError,
699436c6c9cSStella Laurenzo                      "Unbalanced InsertionPoint enter/exit");
700436c6c9cSStella Laurenzo   stack.pop_back();
701436c6c9cSStella Laurenzo }
702436c6c9cSStella Laurenzo 
pushLocation(PyLocation & location)703436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
704436c6c9cSStella Laurenzo   py::object contextObj = location.getContext().getObject();
705436c6c9cSStella Laurenzo   py::object locationObj = py::cast(location);
706436c6c9cSStella Laurenzo   push(FrameKind::Location, /*context=*/contextObj,
707436c6c9cSStella Laurenzo        /*insertionPoint=*/py::object(),
708436c6c9cSStella Laurenzo        /*location=*/locationObj);
709436c6c9cSStella Laurenzo   return locationObj;
710436c6c9cSStella Laurenzo }
711436c6c9cSStella Laurenzo 
popLocation(PyLocation & location)712436c6c9cSStella Laurenzo void PyThreadContextEntry::popLocation(PyLocation &location) {
713436c6c9cSStella Laurenzo   auto &stack = getStack();
714436c6c9cSStella Laurenzo   if (stack.empty())
715436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
716436c6c9cSStella Laurenzo   auto &tos = stack.back();
717436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
718436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
719436c6c9cSStella Laurenzo   stack.pop_back();
720436c6c9cSStella Laurenzo }
721436c6c9cSStella Laurenzo 
722436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
7237ee25bc5SStella Laurenzo // PyDiagnostic*
7247ee25bc5SStella Laurenzo //------------------------------------------------------------------------------
7257ee25bc5SStella Laurenzo 
invalidate()7267ee25bc5SStella Laurenzo void PyDiagnostic::invalidate() {
7277ee25bc5SStella Laurenzo   valid = false;
7287ee25bc5SStella Laurenzo   if (materializedNotes) {
7297ee25bc5SStella Laurenzo     for (auto &noteObject : *materializedNotes) {
7307ee25bc5SStella Laurenzo       PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
7317ee25bc5SStella Laurenzo       note->invalidate();
7327ee25bc5SStella Laurenzo     }
7337ee25bc5SStella Laurenzo   }
7347ee25bc5SStella Laurenzo }
7357ee25bc5SStella Laurenzo 
PyDiagnosticHandler(MlirContext context,py::object callback)7367ee25bc5SStella Laurenzo PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
7377ee25bc5SStella Laurenzo                                          py::object callback)
7387ee25bc5SStella Laurenzo     : context(context), callback(std::move(callback)) {}
7397ee25bc5SStella Laurenzo 
7406a38cbfbSMehdi Amini PyDiagnosticHandler::~PyDiagnosticHandler() = default;
7417ee25bc5SStella Laurenzo 
detach()7427ee25bc5SStella Laurenzo void PyDiagnosticHandler::detach() {
7437ee25bc5SStella Laurenzo   if (!registeredID)
7447ee25bc5SStella Laurenzo     return;
7457ee25bc5SStella Laurenzo   MlirDiagnosticHandlerID localID = *registeredID;
7467ee25bc5SStella Laurenzo   mlirContextDetachDiagnosticHandler(context, localID);
7477ee25bc5SStella Laurenzo   assert(!registeredID && "should have unregistered");
7487ee25bc5SStella Laurenzo   // Not strictly necessary but keeps stale pointers from being around to cause
7497ee25bc5SStella Laurenzo   // issues.
7507ee25bc5SStella Laurenzo   context = {nullptr};
7517ee25bc5SStella Laurenzo }
7527ee25bc5SStella Laurenzo 
checkValid()7537ee25bc5SStella Laurenzo void PyDiagnostic::checkValid() {
7547ee25bc5SStella Laurenzo   if (!valid) {
7557ee25bc5SStella Laurenzo     throw std::invalid_argument(
7567ee25bc5SStella Laurenzo         "Diagnostic is invalid (used outside of callback)");
7577ee25bc5SStella Laurenzo   }
7587ee25bc5SStella Laurenzo }
7597ee25bc5SStella Laurenzo 
getSeverity()7607ee25bc5SStella Laurenzo MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
7617ee25bc5SStella Laurenzo   checkValid();
7627ee25bc5SStella Laurenzo   return mlirDiagnosticGetSeverity(diagnostic);
7637ee25bc5SStella Laurenzo }
7647ee25bc5SStella Laurenzo 
getLocation()7657ee25bc5SStella Laurenzo PyLocation PyDiagnostic::getLocation() {
7667ee25bc5SStella Laurenzo   checkValid();
7677ee25bc5SStella Laurenzo   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
7687ee25bc5SStella Laurenzo   MlirContext context = mlirLocationGetContext(loc);
7697ee25bc5SStella Laurenzo   return PyLocation(PyMlirContext::forContext(context), loc);
7707ee25bc5SStella Laurenzo }
7717ee25bc5SStella Laurenzo 
getMessage()7727ee25bc5SStella Laurenzo py::str PyDiagnostic::getMessage() {
7737ee25bc5SStella Laurenzo   checkValid();
7747ee25bc5SStella Laurenzo   py::object fileObject = py::module::import("io").attr("StringIO")();
7757ee25bc5SStella Laurenzo   PyFileAccumulator accum(fileObject, /*binary=*/false);
7767ee25bc5SStella Laurenzo   mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
7777ee25bc5SStella Laurenzo   return fileObject.attr("getvalue")();
7787ee25bc5SStella Laurenzo }
7797ee25bc5SStella Laurenzo 
getNotes()7807ee25bc5SStella Laurenzo py::tuple PyDiagnostic::getNotes() {
7817ee25bc5SStella Laurenzo   checkValid();
7827ee25bc5SStella Laurenzo   if (materializedNotes)
7837ee25bc5SStella Laurenzo     return *materializedNotes;
7847ee25bc5SStella Laurenzo   intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
7857ee25bc5SStella Laurenzo   materializedNotes = py::tuple(numNotes);
7867ee25bc5SStella Laurenzo   for (intptr_t i = 0; i < numNotes; ++i) {
7877ee25bc5SStella Laurenzo     MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
788*65aedd33Srkayaith     materializedNotes.value()[i] = PyDiagnostic(noteDiag);
7897ee25bc5SStella Laurenzo   }
7907ee25bc5SStella Laurenzo   return *materializedNotes;
7917ee25bc5SStella Laurenzo }
7927ee25bc5SStella Laurenzo 
7937ee25bc5SStella Laurenzo //------------------------------------------------------------------------------
7945e83a5b4SStella Laurenzo // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
795436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
796436c6c9cSStella Laurenzo 
getDialectForKey(const std::string & key,bool attrError)797436c6c9cSStella Laurenzo MlirDialect PyDialects::getDialectForKey(const std::string &key,
798436c6c9cSStella Laurenzo                                          bool attrError) {
799f8479d9dSRiver Riddle   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
800f8479d9dSRiver Riddle                                                     {key.data(), key.size()});
801436c6c9cSStella Laurenzo   if (mlirDialectIsNull(dialect)) {
802436c6c9cSStella Laurenzo     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
803436c6c9cSStella Laurenzo                      Twine("Dialect '") + key + "' not found");
804436c6c9cSStella Laurenzo   }
805436c6c9cSStella Laurenzo   return dialect;
806436c6c9cSStella Laurenzo }
807436c6c9cSStella Laurenzo 
getCapsule()8085e83a5b4SStella Laurenzo py::object PyDialectRegistry::getCapsule() {
8095e83a5b4SStella Laurenzo   return py::reinterpret_steal<py::object>(
8105e83a5b4SStella Laurenzo       mlirPythonDialectRegistryToCapsule(*this));
8115e83a5b4SStella Laurenzo }
8125e83a5b4SStella Laurenzo 
createFromCapsule(py::object capsule)8135e83a5b4SStella Laurenzo PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) {
8145e83a5b4SStella Laurenzo   MlirDialectRegistry rawRegistry =
8155e83a5b4SStella Laurenzo       mlirPythonCapsuleToDialectRegistry(capsule.ptr());
8165e83a5b4SStella Laurenzo   if (mlirDialectRegistryIsNull(rawRegistry))
8175e83a5b4SStella Laurenzo     throw py::error_already_set();
8185e83a5b4SStella Laurenzo   return PyDialectRegistry(rawRegistry);
8195e83a5b4SStella Laurenzo }
8205e83a5b4SStella Laurenzo 
821436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
822436c6c9cSStella Laurenzo // PyLocation
823436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
824436c6c9cSStella Laurenzo 
getCapsule()825436c6c9cSStella Laurenzo py::object PyLocation::getCapsule() {
826436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
827436c6c9cSStella Laurenzo }
828436c6c9cSStella Laurenzo 
createFromCapsule(py::object capsule)829436c6c9cSStella Laurenzo PyLocation PyLocation::createFromCapsule(py::object capsule) {
830436c6c9cSStella Laurenzo   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
831436c6c9cSStella Laurenzo   if (mlirLocationIsNull(rawLoc))
832436c6c9cSStella Laurenzo     throw py::error_already_set();
833436c6c9cSStella Laurenzo   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
834436c6c9cSStella Laurenzo                     rawLoc);
835436c6c9cSStella Laurenzo }
836436c6c9cSStella Laurenzo 
contextEnter()837436c6c9cSStella Laurenzo py::object PyLocation::contextEnter() {
838436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushLocation(*this);
839436c6c9cSStella Laurenzo }
840436c6c9cSStella Laurenzo 
contextExit(const pybind11::object & excType,const pybind11::object & excVal,const pybind11::object & excTb)8411fc096afSMehdi Amini void PyLocation::contextExit(const pybind11::object &excType,
8421fc096afSMehdi Amini                              const pybind11::object &excVal,
8431fc096afSMehdi Amini                              const pybind11::object &excTb) {
844436c6c9cSStella Laurenzo   PyThreadContextEntry::popLocation(*this);
845436c6c9cSStella Laurenzo }
846436c6c9cSStella Laurenzo 
resolve()847436c6c9cSStella Laurenzo PyLocation &DefaultingPyLocation::resolve() {
848436c6c9cSStella Laurenzo   auto *location = PyThreadContextEntry::getDefaultLocation();
849436c6c9cSStella Laurenzo   if (!location) {
850436c6c9cSStella Laurenzo     throw SetPyError(
851436c6c9cSStella Laurenzo         PyExc_RuntimeError,
852436c6c9cSStella Laurenzo         "An MLIR function requires a Location but none was provided in the "
853436c6c9cSStella Laurenzo         "call or from the surrounding environment. Either pass to the function "
854436c6c9cSStella Laurenzo         "with a 'loc=' argument or establish a default using 'with loc:'");
855436c6c9cSStella Laurenzo   }
856436c6c9cSStella Laurenzo   return *location;
857436c6c9cSStella Laurenzo }
858436c6c9cSStella Laurenzo 
859436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
860436c6c9cSStella Laurenzo // PyModule
861436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
862436c6c9cSStella Laurenzo 
PyModule(PyMlirContextRef contextRef,MlirModule module)863436c6c9cSStella Laurenzo PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
864436c6c9cSStella Laurenzo     : BaseContextObject(std::move(contextRef)), module(module) {}
865436c6c9cSStella Laurenzo 
~PyModule()866436c6c9cSStella Laurenzo PyModule::~PyModule() {
867436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
868436c6c9cSStella Laurenzo   auto &liveModules = getContext()->liveModules;
869436c6c9cSStella Laurenzo   assert(liveModules.count(module.ptr) == 1 &&
870436c6c9cSStella Laurenzo          "destroying module not in live map");
871436c6c9cSStella Laurenzo   liveModules.erase(module.ptr);
872436c6c9cSStella Laurenzo   mlirModuleDestroy(module);
873436c6c9cSStella Laurenzo }
874436c6c9cSStella Laurenzo 
forModule(MlirModule module)875436c6c9cSStella Laurenzo PyModuleRef PyModule::forModule(MlirModule module) {
876436c6c9cSStella Laurenzo   MlirContext context = mlirModuleGetContext(module);
877436c6c9cSStella Laurenzo   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
878436c6c9cSStella Laurenzo 
879436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
880436c6c9cSStella Laurenzo   auto &liveModules = contextRef->liveModules;
881436c6c9cSStella Laurenzo   auto it = liveModules.find(module.ptr);
882436c6c9cSStella Laurenzo   if (it == liveModules.end()) {
883436c6c9cSStella Laurenzo     // Create.
884436c6c9cSStella Laurenzo     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
885436c6c9cSStella Laurenzo     // Note that the default return value policy on cast is automatic_reference,
886436c6c9cSStella Laurenzo     // which does not take ownership (delete will not be called).
887436c6c9cSStella Laurenzo     // Just be explicit.
888436c6c9cSStella Laurenzo     py::object pyRef =
889436c6c9cSStella Laurenzo         py::cast(unownedModule, py::return_value_policy::take_ownership);
890436c6c9cSStella Laurenzo     unownedModule->handle = pyRef;
891436c6c9cSStella Laurenzo     liveModules[module.ptr] =
892436c6c9cSStella Laurenzo         std::make_pair(unownedModule->handle, unownedModule);
893436c6c9cSStella Laurenzo     return PyModuleRef(unownedModule, std::move(pyRef));
894436c6c9cSStella Laurenzo   }
895436c6c9cSStella Laurenzo   // Use existing.
896436c6c9cSStella Laurenzo   PyModule *existing = it->second.second;
897436c6c9cSStella Laurenzo   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
898436c6c9cSStella Laurenzo   return PyModuleRef(existing, std::move(pyRef));
899436c6c9cSStella Laurenzo }
900436c6c9cSStella Laurenzo 
createFromCapsule(py::object capsule)901436c6c9cSStella Laurenzo py::object PyModule::createFromCapsule(py::object capsule) {
902436c6c9cSStella Laurenzo   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
903436c6c9cSStella Laurenzo   if (mlirModuleIsNull(rawModule))
904436c6c9cSStella Laurenzo     throw py::error_already_set();
905436c6c9cSStella Laurenzo   return forModule(rawModule).releaseObject();
906436c6c9cSStella Laurenzo }
907436c6c9cSStella Laurenzo 
getCapsule()908436c6c9cSStella Laurenzo py::object PyModule::getCapsule() {
909436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
910436c6c9cSStella Laurenzo }
911436c6c9cSStella Laurenzo 
912436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
913436c6c9cSStella Laurenzo // PyOperation
914436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
915436c6c9cSStella Laurenzo 
PyOperation(PyMlirContextRef contextRef,MlirOperation operation)916436c6c9cSStella Laurenzo PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
917436c6c9cSStella Laurenzo     : BaseContextObject(std::move(contextRef)), operation(operation) {}
918436c6c9cSStella Laurenzo 
~PyOperation()919436c6c9cSStella Laurenzo PyOperation::~PyOperation() {
92049745f87SMike Urbach   // If the operation has already been invalidated there is nothing to do.
92149745f87SMike Urbach   if (!valid)
92249745f87SMike Urbach     return;
923436c6c9cSStella Laurenzo   auto &liveOperations = getContext()->liveOperations;
924436c6c9cSStella Laurenzo   assert(liveOperations.count(operation.ptr) == 1 &&
925436c6c9cSStella Laurenzo          "destroying operation not in live map");
926436c6c9cSStella Laurenzo   liveOperations.erase(operation.ptr);
927436c6c9cSStella Laurenzo   if (!isAttached()) {
928436c6c9cSStella Laurenzo     mlirOperationDestroy(operation);
929436c6c9cSStella Laurenzo   }
930436c6c9cSStella Laurenzo }
931436c6c9cSStella Laurenzo 
createInstance(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)932436c6c9cSStella Laurenzo PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
933436c6c9cSStella Laurenzo                                            MlirOperation operation,
934436c6c9cSStella Laurenzo                                            py::object parentKeepAlive) {
935436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
936436c6c9cSStella Laurenzo   // Create.
937436c6c9cSStella Laurenzo   PyOperation *unownedOperation =
938436c6c9cSStella Laurenzo       new PyOperation(std::move(contextRef), operation);
939436c6c9cSStella Laurenzo   // Note that the default return value policy on cast is automatic_reference,
940436c6c9cSStella Laurenzo   // which does not take ownership (delete will not be called).
941436c6c9cSStella Laurenzo   // Just be explicit.
942436c6c9cSStella Laurenzo   py::object pyRef =
943436c6c9cSStella Laurenzo       py::cast(unownedOperation, py::return_value_policy::take_ownership);
944436c6c9cSStella Laurenzo   unownedOperation->handle = pyRef;
945436c6c9cSStella Laurenzo   if (parentKeepAlive) {
946436c6c9cSStella Laurenzo     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
947436c6c9cSStella Laurenzo   }
948436c6c9cSStella Laurenzo   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
949436c6c9cSStella Laurenzo   return PyOperationRef(unownedOperation, std::move(pyRef));
950436c6c9cSStella Laurenzo }
951436c6c9cSStella Laurenzo 
forOperation(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)952436c6c9cSStella Laurenzo PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
953436c6c9cSStella Laurenzo                                          MlirOperation operation,
954436c6c9cSStella Laurenzo                                          py::object parentKeepAlive) {
955436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
956436c6c9cSStella Laurenzo   auto it = liveOperations.find(operation.ptr);
957436c6c9cSStella Laurenzo   if (it == liveOperations.end()) {
958436c6c9cSStella Laurenzo     // Create.
959436c6c9cSStella Laurenzo     return createInstance(std::move(contextRef), operation,
960436c6c9cSStella Laurenzo                           std::move(parentKeepAlive));
961436c6c9cSStella Laurenzo   }
962436c6c9cSStella Laurenzo   // Use existing.
963436c6c9cSStella Laurenzo   PyOperation *existing = it->second.second;
964436c6c9cSStella Laurenzo   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
965436c6c9cSStella Laurenzo   return PyOperationRef(existing, std::move(pyRef));
966436c6c9cSStella Laurenzo }
967436c6c9cSStella Laurenzo 
createDetached(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)968436c6c9cSStella Laurenzo PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
969436c6c9cSStella Laurenzo                                            MlirOperation operation,
970436c6c9cSStella Laurenzo                                            py::object parentKeepAlive) {
971436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
972436c6c9cSStella Laurenzo   assert(liveOperations.count(operation.ptr) == 0 &&
973436c6c9cSStella Laurenzo          "cannot create detached operation that already exists");
974436c6c9cSStella Laurenzo   (void)liveOperations;
975436c6c9cSStella Laurenzo 
976436c6c9cSStella Laurenzo   PyOperationRef created = createInstance(std::move(contextRef), operation,
977436c6c9cSStella Laurenzo                                           std::move(parentKeepAlive));
978436c6c9cSStella Laurenzo   created->attached = false;
979436c6c9cSStella Laurenzo   return created;
980436c6c9cSStella Laurenzo }
981436c6c9cSStella Laurenzo 
checkValid() const982436c6c9cSStella Laurenzo void PyOperation::checkValid() const {
983436c6c9cSStella Laurenzo   if (!valid) {
984436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
985436c6c9cSStella Laurenzo   }
986436c6c9cSStella Laurenzo }
987436c6c9cSStella Laurenzo 
print(py::object fileObject,bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope,bool assumeVerified)988436c6c9cSStella Laurenzo void PyOperationBase::print(py::object fileObject, bool binary,
989436c6c9cSStella Laurenzo                             llvm::Optional<int64_t> largeElementsLimit,
990436c6c9cSStella Laurenzo                             bool enableDebugInfo, bool prettyDebugInfo,
991ace1d0adSStella Laurenzo                             bool printGenericOpForm, bool useLocalScope,
992ace1d0adSStella Laurenzo                             bool assumeVerified) {
993436c6c9cSStella Laurenzo   PyOperation &operation = getOperation();
994436c6c9cSStella Laurenzo   operation.checkValid();
995436c6c9cSStella Laurenzo   if (fileObject.is_none())
996436c6c9cSStella Laurenzo     fileObject = py::module::import("sys").attr("stdout");
997436c6c9cSStella Laurenzo 
998ace1d0adSStella Laurenzo   if (!assumeVerified && !printGenericOpForm &&
999ace1d0adSStella Laurenzo       !mlirOperationVerify(operation)) {
1000ace1d0adSStella Laurenzo     std::string message("// Verification failed, printing generic form\n");
1001ace1d0adSStella Laurenzo     if (binary) {
1002ace1d0adSStella Laurenzo       fileObject.attr("write")(py::bytes(message));
1003ace1d0adSStella Laurenzo     } else {
1004ace1d0adSStella Laurenzo       fileObject.attr("write")(py::str(message));
1005ace1d0adSStella Laurenzo     }
1006436c6c9cSStella Laurenzo     printGenericOpForm = true;
1007436c6c9cSStella Laurenzo   }
1008436c6c9cSStella Laurenzo 
1009436c6c9cSStella Laurenzo   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1010436c6c9cSStella Laurenzo   if (largeElementsLimit)
1011436c6c9cSStella Laurenzo     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1012436c6c9cSStella Laurenzo   if (enableDebugInfo)
1013436c6c9cSStella Laurenzo     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
1014436c6c9cSStella Laurenzo   if (printGenericOpForm)
1015436c6c9cSStella Laurenzo     mlirOpPrintingFlagsPrintGenericOpForm(flags);
1016bccf27d9SMark Browning   if (useLocalScope)
1017bccf27d9SMark Browning     mlirOpPrintingFlagsUseLocalScope(flags);
1018436c6c9cSStella Laurenzo 
1019436c6c9cSStella Laurenzo   PyFileAccumulator accum(fileObject, binary);
1020436c6c9cSStella Laurenzo   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1021436c6c9cSStella Laurenzo                               accum.getUserData());
1022436c6c9cSStella Laurenzo   mlirOpPrintingFlagsDestroy(flags);
1023436c6c9cSStella Laurenzo }
1024436c6c9cSStella Laurenzo 
getAsm(bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope,bool assumeVerified)1025436c6c9cSStella Laurenzo py::object PyOperationBase::getAsm(bool binary,
1026436c6c9cSStella Laurenzo                                    llvm::Optional<int64_t> largeElementsLimit,
1027436c6c9cSStella Laurenzo                                    bool enableDebugInfo, bool prettyDebugInfo,
1028ace1d0adSStella Laurenzo                                    bool printGenericOpForm, bool useLocalScope,
1029ace1d0adSStella Laurenzo                                    bool assumeVerified) {
1030436c6c9cSStella Laurenzo   py::object fileObject;
1031436c6c9cSStella Laurenzo   if (binary) {
1032436c6c9cSStella Laurenzo     fileObject = py::module::import("io").attr("BytesIO")();
1033436c6c9cSStella Laurenzo   } else {
1034436c6c9cSStella Laurenzo     fileObject = py::module::import("io").attr("StringIO")();
1035436c6c9cSStella Laurenzo   }
1036436c6c9cSStella Laurenzo   print(fileObject, /*binary=*/binary,
1037436c6c9cSStella Laurenzo         /*largeElementsLimit=*/largeElementsLimit,
1038436c6c9cSStella Laurenzo         /*enableDebugInfo=*/enableDebugInfo,
1039436c6c9cSStella Laurenzo         /*prettyDebugInfo=*/prettyDebugInfo,
1040436c6c9cSStella Laurenzo         /*printGenericOpForm=*/printGenericOpForm,
1041ace1d0adSStella Laurenzo         /*useLocalScope=*/useLocalScope,
1042ace1d0adSStella Laurenzo         /*assumeVerified=*/assumeVerified);
1043436c6c9cSStella Laurenzo 
1044436c6c9cSStella Laurenzo   return fileObject.attr("getvalue")();
1045436c6c9cSStella Laurenzo }
1046436c6c9cSStella Laurenzo 
moveAfter(PyOperationBase & other)104724685aaeSAlex Zinenko void PyOperationBase::moveAfter(PyOperationBase &other) {
104824685aaeSAlex Zinenko   PyOperation &operation = getOperation();
104924685aaeSAlex Zinenko   PyOperation &otherOp = other.getOperation();
105024685aaeSAlex Zinenko   operation.checkValid();
105124685aaeSAlex Zinenko   otherOp.checkValid();
105224685aaeSAlex Zinenko   mlirOperationMoveAfter(operation, otherOp);
105324685aaeSAlex Zinenko   operation.parentKeepAlive = otherOp.parentKeepAlive;
105424685aaeSAlex Zinenko }
105524685aaeSAlex Zinenko 
moveBefore(PyOperationBase & other)105624685aaeSAlex Zinenko void PyOperationBase::moveBefore(PyOperationBase &other) {
105724685aaeSAlex Zinenko   PyOperation &operation = getOperation();
105824685aaeSAlex Zinenko   PyOperation &otherOp = other.getOperation();
105924685aaeSAlex Zinenko   operation.checkValid();
106024685aaeSAlex Zinenko   otherOp.checkValid();
106124685aaeSAlex Zinenko   mlirOperationMoveBefore(operation, otherOp);
106224685aaeSAlex Zinenko   operation.parentKeepAlive = otherOp.parentKeepAlive;
106324685aaeSAlex Zinenko }
106424685aaeSAlex Zinenko 
getParentOperation()10651689dadeSJohn Demme llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
106649745f87SMike Urbach   checkValid();
1067436c6c9cSStella Laurenzo   if (!isAttached())
1068436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
1069436c6c9cSStella Laurenzo   MlirOperation operation = mlirOperationGetParentOperation(get());
1070436c6c9cSStella Laurenzo   if (mlirOperationIsNull(operation))
10711689dadeSJohn Demme     return {};
1072436c6c9cSStella Laurenzo   return PyOperation::forOperation(getContext(), operation);
1073436c6c9cSStella Laurenzo }
1074436c6c9cSStella Laurenzo 
getBlock()1075436c6c9cSStella Laurenzo PyBlock PyOperation::getBlock() {
107649745f87SMike Urbach   checkValid();
10771689dadeSJohn Demme   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
1078436c6c9cSStella Laurenzo   MlirBlock block = mlirOperationGetBlock(get());
1079436c6c9cSStella Laurenzo   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
10801689dadeSJohn Demme   assert(parentOperation && "Operation has no parent");
10811689dadeSJohn Demme   return PyBlock{std::move(*parentOperation), block};
1082436c6c9cSStella Laurenzo }
1083436c6c9cSStella Laurenzo 
getCapsule()10840126e906SJohn Demme py::object PyOperation::getCapsule() {
108549745f87SMike Urbach   checkValid();
10860126e906SJohn Demme   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
10870126e906SJohn Demme }
10880126e906SJohn Demme 
createFromCapsule(py::object capsule)10890126e906SJohn Demme py::object PyOperation::createFromCapsule(py::object capsule) {
10900126e906SJohn Demme   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
10910126e906SJohn Demme   if (mlirOperationIsNull(rawOperation))
10920126e906SJohn Demme     throw py::error_already_set();
10930126e906SJohn Demme   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
10940126e906SJohn Demme   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
10950126e906SJohn Demme       .releaseObject();
10960126e906SJohn Demme }
10970126e906SJohn Demme 
maybeInsertOperation(PyOperationRef & op,const py::object & maybeIp)1098774818c0SDominik Grewe static void maybeInsertOperation(PyOperationRef &op,
1099774818c0SDominik Grewe                                  const py::object &maybeIp) {
1100774818c0SDominik Grewe   // InsertPoint active?
1101774818c0SDominik Grewe   if (!maybeIp.is(py::cast(false))) {
1102774818c0SDominik Grewe     PyInsertionPoint *ip;
1103774818c0SDominik Grewe     if (maybeIp.is_none()) {
1104774818c0SDominik Grewe       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1105774818c0SDominik Grewe     } else {
1106774818c0SDominik Grewe       ip = py::cast<PyInsertionPoint *>(maybeIp);
1107774818c0SDominik Grewe     }
1108774818c0SDominik Grewe     if (ip)
1109774818c0SDominik Grewe       ip->insert(*op.get());
1110774818c0SDominik Grewe   }
1111774818c0SDominik Grewe }
1112774818c0SDominik Grewe 
create(const std::string & name,llvm::Optional<std::vector<PyType * >> results,llvm::Optional<std::vector<PyValue * >> operands,llvm::Optional<py::dict> attributes,llvm::Optional<std::vector<PyBlock * >> successors,int regions,DefaultingPyLocation location,const py::object & maybeIp)1113436c6c9cSStella Laurenzo py::object PyOperation::create(
11141fc096afSMehdi Amini     const std::string &name, llvm::Optional<std::vector<PyType *>> results,
1115436c6c9cSStella Laurenzo     llvm::Optional<std::vector<PyValue *>> operands,
1116436c6c9cSStella Laurenzo     llvm::Optional<py::dict> attributes,
1117436c6c9cSStella Laurenzo     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
11181fc096afSMehdi Amini     DefaultingPyLocation location, const py::object &maybeIp) {
1119436c6c9cSStella Laurenzo   llvm::SmallVector<MlirValue, 4> mlirOperands;
1120436c6c9cSStella Laurenzo   llvm::SmallVector<MlirType, 4> mlirResults;
1121436c6c9cSStella Laurenzo   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
1122436c6c9cSStella Laurenzo   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
1123436c6c9cSStella Laurenzo 
1124436c6c9cSStella Laurenzo   // General parameter validation.
1125436c6c9cSStella Laurenzo   if (regions < 0)
1126436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
1127436c6c9cSStella Laurenzo 
1128436c6c9cSStella Laurenzo   // Unpack/validate operands.
1129436c6c9cSStella Laurenzo   if (operands) {
1130436c6c9cSStella Laurenzo     mlirOperands.reserve(operands->size());
1131436c6c9cSStella Laurenzo     for (PyValue *operand : *operands) {
1132436c6c9cSStella Laurenzo       if (!operand)
1133436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
1134436c6c9cSStella Laurenzo       mlirOperands.push_back(operand->get());
1135436c6c9cSStella Laurenzo     }
1136436c6c9cSStella Laurenzo   }
1137436c6c9cSStella Laurenzo 
1138436c6c9cSStella Laurenzo   // Unpack/validate results.
1139436c6c9cSStella Laurenzo   if (results) {
1140436c6c9cSStella Laurenzo     mlirResults.reserve(results->size());
1141436c6c9cSStella Laurenzo     for (PyType *result : *results) {
1142436c6c9cSStella Laurenzo       // TODO: Verify result type originate from the same context.
1143436c6c9cSStella Laurenzo       if (!result)
1144436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "result type cannot be None");
1145436c6c9cSStella Laurenzo       mlirResults.push_back(*result);
1146436c6c9cSStella Laurenzo     }
1147436c6c9cSStella Laurenzo   }
1148436c6c9cSStella Laurenzo   // Unpack/validate attributes.
1149436c6c9cSStella Laurenzo   if (attributes) {
1150436c6c9cSStella Laurenzo     mlirAttributes.reserve(attributes->size());
1151436c6c9cSStella Laurenzo     for (auto &it : *attributes) {
1152436c6c9cSStella Laurenzo       std::string key;
1153436c6c9cSStella Laurenzo       try {
1154436c6c9cSStella Laurenzo         key = it.first.cast<std::string>();
1155436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
1156436c6c9cSStella Laurenzo         std::string msg = "Invalid attribute key (not a string) when "
1157436c6c9cSStella Laurenzo                           "attempting to create the operation \"" +
1158436c6c9cSStella Laurenzo                           name + "\" (" + err.what() + ")";
1159436c6c9cSStella Laurenzo         throw py::cast_error(msg);
1160436c6c9cSStella Laurenzo       }
1161436c6c9cSStella Laurenzo       try {
1162436c6c9cSStella Laurenzo         auto &attribute = it.second.cast<PyAttribute &>();
1163436c6c9cSStella Laurenzo         // TODO: Verify attribute originates from the same context.
1164436c6c9cSStella Laurenzo         mlirAttributes.emplace_back(std::move(key), attribute);
1165436c6c9cSStella Laurenzo       } catch (py::reference_cast_error &) {
1166436c6c9cSStella Laurenzo         // This exception seems thrown when the value is "None".
1167436c6c9cSStella Laurenzo         std::string msg =
1168436c6c9cSStella Laurenzo             "Found an invalid (`None`?) attribute value for the key \"" + key +
1169436c6c9cSStella Laurenzo             "\" when attempting to create the operation \"" + name + "\"";
1170436c6c9cSStella Laurenzo         throw py::cast_error(msg);
1171436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
1172436c6c9cSStella Laurenzo         std::string msg = "Invalid attribute value for the key \"" + key +
1173436c6c9cSStella Laurenzo                           "\" when attempting to create the operation \"" +
1174436c6c9cSStella Laurenzo                           name + "\" (" + err.what() + ")";
1175436c6c9cSStella Laurenzo         throw py::cast_error(msg);
1176436c6c9cSStella Laurenzo       }
1177436c6c9cSStella Laurenzo     }
1178436c6c9cSStella Laurenzo   }
1179436c6c9cSStella Laurenzo   // Unpack/validate successors.
1180436c6c9cSStella Laurenzo   if (successors) {
1181436c6c9cSStella Laurenzo     mlirSuccessors.reserve(successors->size());
1182436c6c9cSStella Laurenzo     for (auto *successor : *successors) {
1183436c6c9cSStella Laurenzo       // TODO: Verify successor originate from the same context.
1184436c6c9cSStella Laurenzo       if (!successor)
1185436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
1186436c6c9cSStella Laurenzo       mlirSuccessors.push_back(successor->get());
1187436c6c9cSStella Laurenzo     }
1188436c6c9cSStella Laurenzo   }
1189436c6c9cSStella Laurenzo 
1190436c6c9cSStella Laurenzo   // Apply unpacked/validated to the operation state. Beyond this
1191436c6c9cSStella Laurenzo   // point, exceptions cannot be thrown or else the state will leak.
1192436c6c9cSStella Laurenzo   MlirOperationState state =
1193436c6c9cSStella Laurenzo       mlirOperationStateGet(toMlirStringRef(name), location);
1194436c6c9cSStella Laurenzo   if (!mlirOperands.empty())
1195436c6c9cSStella Laurenzo     mlirOperationStateAddOperands(&state, mlirOperands.size(),
1196436c6c9cSStella Laurenzo                                   mlirOperands.data());
1197436c6c9cSStella Laurenzo   if (!mlirResults.empty())
1198436c6c9cSStella Laurenzo     mlirOperationStateAddResults(&state, mlirResults.size(),
1199436c6c9cSStella Laurenzo                                  mlirResults.data());
1200436c6c9cSStella Laurenzo   if (!mlirAttributes.empty()) {
1201436c6c9cSStella Laurenzo     // Note that the attribute names directly reference bytes in
1202436c6c9cSStella Laurenzo     // mlirAttributes, so that vector must not be changed from here
1203436c6c9cSStella Laurenzo     // on.
1204436c6c9cSStella Laurenzo     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1205436c6c9cSStella Laurenzo     mlirNamedAttributes.reserve(mlirAttributes.size());
1206436c6c9cSStella Laurenzo     for (auto &it : mlirAttributes)
1207436c6c9cSStella Laurenzo       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1208436c6c9cSStella Laurenzo           mlirIdentifierGet(mlirAttributeGetContext(it.second),
1209436c6c9cSStella Laurenzo                             toMlirStringRef(it.first)),
1210436c6c9cSStella Laurenzo           it.second));
1211436c6c9cSStella Laurenzo     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1212436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1213436c6c9cSStella Laurenzo   }
1214436c6c9cSStella Laurenzo   if (!mlirSuccessors.empty())
1215436c6c9cSStella Laurenzo     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1216436c6c9cSStella Laurenzo                                     mlirSuccessors.data());
1217436c6c9cSStella Laurenzo   if (regions) {
1218436c6c9cSStella Laurenzo     llvm::SmallVector<MlirRegion, 4> mlirRegions;
1219436c6c9cSStella Laurenzo     mlirRegions.resize(regions);
1220436c6c9cSStella Laurenzo     for (int i = 0; i < regions; ++i)
1221436c6c9cSStella Laurenzo       mlirRegions[i] = mlirRegionCreate();
1222436c6c9cSStella Laurenzo     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1223436c6c9cSStella Laurenzo                                       mlirRegions.data());
1224436c6c9cSStella Laurenzo   }
1225436c6c9cSStella Laurenzo 
1226436c6c9cSStella Laurenzo   // Construct the operation.
1227436c6c9cSStella Laurenzo   MlirOperation operation = mlirOperationCreate(&state);
1228436c6c9cSStella Laurenzo   PyOperationRef created =
1229436c6c9cSStella Laurenzo       PyOperation::createDetached(location->getContext(), operation);
1230774818c0SDominik Grewe   maybeInsertOperation(created, maybeIp);
1231436c6c9cSStella Laurenzo 
1232436c6c9cSStella Laurenzo   return created->createOpView();
1233436c6c9cSStella Laurenzo }
1234436c6c9cSStella Laurenzo 
// Deep-clones this operation and returns an OpView for the copy. The clone
// starts detached and is then optionally inserted according to `maybeIp`
// (False: keep detached; None: default insertion point; else an
// InsertionPoint).
py::object PyOperation::clone(const py::object &maybeIp) {
  // Like every other accessor, refuse to touch an operation that has been
  // invalidated (e.g. by erase()). Without this check, cloning reads a
  // destroyed MlirOperation.
  checkValid();
  MlirOperation clonedOperation = mlirOperationClone(get());
  PyOperationRef cloned =
      PyOperation::createDetached(getContext(), clonedOperation);
  maybeInsertOperation(cloned, maybeIp);

  return cloned->createOpView();
}
1243774818c0SDominik Grewe 
createOpView()1244436c6c9cSStella Laurenzo py::object PyOperation::createOpView() {
124549745f87SMike Urbach   checkValid();
1246436c6c9cSStella Laurenzo   MlirIdentifier ident = mlirOperationGetName(get());
1247436c6c9cSStella Laurenzo   MlirStringRef identStr = mlirIdentifierStr(ident);
1248436c6c9cSStella Laurenzo   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1249436c6c9cSStella Laurenzo       StringRef(identStr.data, identStr.length));
1250436c6c9cSStella Laurenzo   if (opViewClass)
1251436c6c9cSStella Laurenzo     return (*opViewClass)(getRef().getObject());
1252436c6c9cSStella Laurenzo   return py::cast(PyOpView(getRef().getObject()));
1253436c6c9cSStella Laurenzo }
1254436c6c9cSStella Laurenzo 
erase()125549745f87SMike Urbach void PyOperation::erase() {
125649745f87SMike Urbach   checkValid();
125749745f87SMike Urbach   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
125849745f87SMike Urbach   // Python reference to a child operation is live. All children should also
125949745f87SMike Urbach   // have their `valid` bit set to false.
126049745f87SMike Urbach   auto &liveOperations = getContext()->liveOperations;
126149745f87SMike Urbach   if (liveOperations.count(operation.ptr))
126249745f87SMike Urbach     liveOperations.erase(operation.ptr);
126349745f87SMike Urbach   mlirOperationDestroy(operation);
126449745f87SMike Urbach   valid = false;
126549745f87SMike Urbach }
126649745f87SMike Urbach 
1267436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1268436c6c9cSStella Laurenzo // PyOpView
1269436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1270436c6c9cSStella Laurenzo 
buildGeneric(const py::object & cls,py::list resultTypeList,py::list operandList,llvm::Optional<py::dict> attributes,llvm::Optional<std::vector<PyBlock * >> successors,llvm::Optional<int> regions,DefaultingPyLocation location,const py::object & maybeIp)12714f415216SMehdi Amini py::object PyOpView::buildGeneric(
12724f415216SMehdi Amini     const py::object &cls, py::list resultTypeList, py::list operandList,
1273436c6c9cSStella Laurenzo     llvm::Optional<py::dict> attributes,
1274436c6c9cSStella Laurenzo     llvm::Optional<std::vector<PyBlock *>> successors,
12754f415216SMehdi Amini     llvm::Optional<int> regions, DefaultingPyLocation location,
12764f415216SMehdi Amini     const py::object &maybeIp) {
1277436c6c9cSStella Laurenzo   PyMlirContextRef context = location->getContext();
1278436c6c9cSStella Laurenzo   // Class level operation construction metadata.
1279436c6c9cSStella Laurenzo   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1280436c6c9cSStella Laurenzo   // Operand and result segment specs are either none, which does no
1281436c6c9cSStella Laurenzo   // variadic unpacking, or a list of ints with segment sizes, where each
1282436c6c9cSStella Laurenzo   // element is either a positive number (typically 1 for a scalar) or -1 to
1283436c6c9cSStella Laurenzo   // indicate that it is derived from the length of the same-indexed operand
1284436c6c9cSStella Laurenzo   // or result (implying that it is a list at that position).
1285436c6c9cSStella Laurenzo   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1286436c6c9cSStella Laurenzo   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1287436c6c9cSStella Laurenzo 
12888d05a288SStella Laurenzo   std::vector<uint32_t> operandSegmentLengths;
12898d05a288SStella Laurenzo   std::vector<uint32_t> resultSegmentLengths;
1290436c6c9cSStella Laurenzo 
1291436c6c9cSStella Laurenzo   // Validate/determine region count.
1292436c6c9cSStella Laurenzo   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1293436c6c9cSStella Laurenzo   int opMinRegionCount = std::get<0>(opRegionSpec);
1294436c6c9cSStella Laurenzo   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1295436c6c9cSStella Laurenzo   if (!regions) {
1296436c6c9cSStella Laurenzo     regions = opMinRegionCount;
1297436c6c9cSStella Laurenzo   }
1298436c6c9cSStella Laurenzo   if (*regions < opMinRegionCount) {
1299436c6c9cSStella Laurenzo     throw py::value_error(
1300436c6c9cSStella Laurenzo         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1301436c6c9cSStella Laurenzo          llvm::Twine(opMinRegionCount) +
1302436c6c9cSStella Laurenzo          " regions but was built with regions=" + llvm::Twine(*regions))
1303436c6c9cSStella Laurenzo             .str());
1304436c6c9cSStella Laurenzo   }
1305436c6c9cSStella Laurenzo   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1306436c6c9cSStella Laurenzo     throw py::value_error(
1307436c6c9cSStella Laurenzo         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1308436c6c9cSStella Laurenzo          llvm::Twine(opMinRegionCount) +
1309436c6c9cSStella Laurenzo          " regions but was built with regions=" + llvm::Twine(*regions))
1310436c6c9cSStella Laurenzo             .str());
1311436c6c9cSStella Laurenzo   }
1312436c6c9cSStella Laurenzo 
1313436c6c9cSStella Laurenzo   // Unpack results.
1314436c6c9cSStella Laurenzo   std::vector<PyType *> resultTypes;
1315436c6c9cSStella Laurenzo   resultTypes.reserve(resultTypeList.size());
1316436c6c9cSStella Laurenzo   if (resultSegmentSpecObj.is_none()) {
1317436c6c9cSStella Laurenzo     // Non-variadic result unpacking.
1318e4853be2SMehdi Amini     for (const auto &it : llvm::enumerate(resultTypeList)) {
1319436c6c9cSStella Laurenzo       try {
1320436c6c9cSStella Laurenzo         resultTypes.push_back(py::cast<PyType *>(it.value()));
1321436c6c9cSStella Laurenzo         if (!resultTypes.back())
1322436c6c9cSStella Laurenzo           throw py::cast_error();
1323436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
1324436c6c9cSStella Laurenzo         throw py::value_error((llvm::Twine("Result ") +
1325436c6c9cSStella Laurenzo                                llvm::Twine(it.index()) + " of operation \"" +
1326436c6c9cSStella Laurenzo                                name + "\" must be a Type (" + err.what() + ")")
1327436c6c9cSStella Laurenzo                                   .str());
1328436c6c9cSStella Laurenzo       }
1329436c6c9cSStella Laurenzo     }
1330436c6c9cSStella Laurenzo   } else {
1331436c6c9cSStella Laurenzo     // Sized result unpacking.
1332436c6c9cSStella Laurenzo     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1333436c6c9cSStella Laurenzo     if (resultSegmentSpec.size() != resultTypeList.size()) {
1334436c6c9cSStella Laurenzo       throw py::value_error((llvm::Twine("Operation \"") + name +
1335436c6c9cSStella Laurenzo                              "\" requires " +
1336436c6c9cSStella Laurenzo                              llvm::Twine(resultSegmentSpec.size()) +
1337436c6c9cSStella Laurenzo                              " result segments but was provided " +
1338436c6c9cSStella Laurenzo                              llvm::Twine(resultTypeList.size()))
1339436c6c9cSStella Laurenzo                                 .str());
1340436c6c9cSStella Laurenzo     }
1341436c6c9cSStella Laurenzo     resultSegmentLengths.reserve(resultTypeList.size());
1342e4853be2SMehdi Amini     for (const auto &it :
1343436c6c9cSStella Laurenzo          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1344436c6c9cSStella Laurenzo       int segmentSpec = std::get<1>(it.value());
1345436c6c9cSStella Laurenzo       if (segmentSpec == 1 || segmentSpec == 0) {
1346436c6c9cSStella Laurenzo         // Unpack unary element.
1347436c6c9cSStella Laurenzo         try {
13486981e5ecSAlex Zinenko           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1349436c6c9cSStella Laurenzo           if (resultType) {
1350436c6c9cSStella Laurenzo             resultTypes.push_back(resultType);
1351436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(1);
1352436c6c9cSStella Laurenzo           } else if (segmentSpec == 0) {
1353436c6c9cSStella Laurenzo             // Allowed to be optional.
1354436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(0);
1355436c6c9cSStella Laurenzo           } else {
1356436c6c9cSStella Laurenzo             throw py::cast_error("was None and result is not optional");
1357436c6c9cSStella Laurenzo           }
1358436c6c9cSStella Laurenzo         } catch (py::cast_error &err) {
1359436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Result ") +
1360436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1361436c6c9cSStella Laurenzo                                  name + "\" must be a Type (" + err.what() +
1362436c6c9cSStella Laurenzo                                  ")")
1363436c6c9cSStella Laurenzo                                     .str());
1364436c6c9cSStella Laurenzo         }
1365436c6c9cSStella Laurenzo       } else if (segmentSpec == -1) {
1366436c6c9cSStella Laurenzo         // Unpack sequence by appending.
1367436c6c9cSStella Laurenzo         try {
1368436c6c9cSStella Laurenzo           if (std::get<0>(it.value()).is_none()) {
1369436c6c9cSStella Laurenzo             // Treat it as an empty list.
1370436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(0);
1371436c6c9cSStella Laurenzo           } else {
1372436c6c9cSStella Laurenzo             // Unpack the list.
1373436c6c9cSStella Laurenzo             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1374436c6c9cSStella Laurenzo             for (py::object segmentItem : segment) {
1375436c6c9cSStella Laurenzo               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1376436c6c9cSStella Laurenzo               if (!resultTypes.back()) {
1377436c6c9cSStella Laurenzo                 throw py::cast_error("contained a None item");
1378436c6c9cSStella Laurenzo               }
1379436c6c9cSStella Laurenzo             }
1380436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(segment.size());
1381436c6c9cSStella Laurenzo           }
1382436c6c9cSStella Laurenzo         } catch (std::exception &err) {
1383436c6c9cSStella Laurenzo           // NOTE: Sloppy to be using a catch-all here, but there are at least
1384436c6c9cSStella Laurenzo           // three different unrelated exceptions that can be thrown in the
1385436c6c9cSStella Laurenzo           // above "casts". Just keep the scope above small and catch them all.
1386436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Result ") +
1387436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1388436c6c9cSStella Laurenzo                                  name + "\" must be a Sequence of Types (" +
1389436c6c9cSStella Laurenzo                                  err.what() + ")")
1390436c6c9cSStella Laurenzo                                     .str());
1391436c6c9cSStella Laurenzo         }
1392436c6c9cSStella Laurenzo       } else {
1393436c6c9cSStella Laurenzo         throw py::value_error("Unexpected segment spec");
1394436c6c9cSStella Laurenzo       }
1395436c6c9cSStella Laurenzo     }
1396436c6c9cSStella Laurenzo   }
1397436c6c9cSStella Laurenzo 
1398436c6c9cSStella Laurenzo   // Unpack operands.
1399436c6c9cSStella Laurenzo   std::vector<PyValue *> operands;
1400436c6c9cSStella Laurenzo   operands.reserve(operands.size());
1401436c6c9cSStella Laurenzo   if (operandSegmentSpecObj.is_none()) {
1402436c6c9cSStella Laurenzo     // Non-sized operand unpacking.
1403e4853be2SMehdi Amini     for (const auto &it : llvm::enumerate(operandList)) {
1404436c6c9cSStella Laurenzo       try {
1405436c6c9cSStella Laurenzo         operands.push_back(py::cast<PyValue *>(it.value()));
1406436c6c9cSStella Laurenzo         if (!operands.back())
1407436c6c9cSStella Laurenzo           throw py::cast_error();
1408436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
1409436c6c9cSStella Laurenzo         throw py::value_error((llvm::Twine("Operand ") +
1410436c6c9cSStella Laurenzo                                llvm::Twine(it.index()) + " of operation \"" +
1411436c6c9cSStella Laurenzo                                name + "\" must be a Value (" + err.what() + ")")
1412436c6c9cSStella Laurenzo                                   .str());
1413436c6c9cSStella Laurenzo       }
1414436c6c9cSStella Laurenzo     }
1415436c6c9cSStella Laurenzo   } else {
1416436c6c9cSStella Laurenzo     // Sized operand unpacking.
1417436c6c9cSStella Laurenzo     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1418436c6c9cSStella Laurenzo     if (operandSegmentSpec.size() != operandList.size()) {
1419436c6c9cSStella Laurenzo       throw py::value_error((llvm::Twine("Operation \"") + name +
1420436c6c9cSStella Laurenzo                              "\" requires " +
1421436c6c9cSStella Laurenzo                              llvm::Twine(operandSegmentSpec.size()) +
1422436c6c9cSStella Laurenzo                              "operand segments but was provided " +
1423436c6c9cSStella Laurenzo                              llvm::Twine(operandList.size()))
1424436c6c9cSStella Laurenzo                                 .str());
1425436c6c9cSStella Laurenzo     }
1426436c6c9cSStella Laurenzo     operandSegmentLengths.reserve(operandList.size());
1427e4853be2SMehdi Amini     for (const auto &it :
1428436c6c9cSStella Laurenzo          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1429436c6c9cSStella Laurenzo       int segmentSpec = std::get<1>(it.value());
1430436c6c9cSStella Laurenzo       if (segmentSpec == 1 || segmentSpec == 0) {
1431436c6c9cSStella Laurenzo         // Unpack unary element.
1432436c6c9cSStella Laurenzo         try {
143302b6fb21SMehdi Amini           auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1434436c6c9cSStella Laurenzo           if (operandValue) {
1435436c6c9cSStella Laurenzo             operands.push_back(operandValue);
1436436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(1);
1437436c6c9cSStella Laurenzo           } else if (segmentSpec == 0) {
1438436c6c9cSStella Laurenzo             // Allowed to be optional.
1439436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(0);
1440436c6c9cSStella Laurenzo           } else {
1441436c6c9cSStella Laurenzo             throw py::cast_error("was None and operand is not optional");
1442436c6c9cSStella Laurenzo           }
1443436c6c9cSStella Laurenzo         } catch (py::cast_error &err) {
1444436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Operand ") +
1445436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1446436c6c9cSStella Laurenzo                                  name + "\" must be a Value (" + err.what() +
1447436c6c9cSStella Laurenzo                                  ")")
1448436c6c9cSStella Laurenzo                                     .str());
1449436c6c9cSStella Laurenzo         }
1450436c6c9cSStella Laurenzo       } else if (segmentSpec == -1) {
1451436c6c9cSStella Laurenzo         // Unpack sequence by appending.
1452436c6c9cSStella Laurenzo         try {
1453436c6c9cSStella Laurenzo           if (std::get<0>(it.value()).is_none()) {
1454436c6c9cSStella Laurenzo             // Treat it as an empty list.
1455436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(0);
1456436c6c9cSStella Laurenzo           } else {
1457436c6c9cSStella Laurenzo             // Unpack the list.
1458436c6c9cSStella Laurenzo             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1459436c6c9cSStella Laurenzo             for (py::object segmentItem : segment) {
1460436c6c9cSStella Laurenzo               operands.push_back(py::cast<PyValue *>(segmentItem));
1461436c6c9cSStella Laurenzo               if (!operands.back()) {
1462436c6c9cSStella Laurenzo                 throw py::cast_error("contained a None item");
1463436c6c9cSStella Laurenzo               }
1464436c6c9cSStella Laurenzo             }
1465436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(segment.size());
1466436c6c9cSStella Laurenzo           }
1467436c6c9cSStella Laurenzo         } catch (std::exception &err) {
1468436c6c9cSStella Laurenzo           // NOTE: Sloppy to be using a catch-all here, but there are at least
1469436c6c9cSStella Laurenzo           // three different unrelated exceptions that can be thrown in the
1470436c6c9cSStella Laurenzo           // above "casts". Just keep the scope above small and catch them all.
1471436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Operand ") +
1472436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1473436c6c9cSStella Laurenzo                                  name + "\" must be a Sequence of Values (" +
1474436c6c9cSStella Laurenzo                                  err.what() + ")")
1475436c6c9cSStella Laurenzo                                     .str());
1476436c6c9cSStella Laurenzo         }
1477436c6c9cSStella Laurenzo       } else {
1478436c6c9cSStella Laurenzo         throw py::value_error("Unexpected segment spec");
1479436c6c9cSStella Laurenzo       }
1480436c6c9cSStella Laurenzo     }
1481436c6c9cSStella Laurenzo   }
1482436c6c9cSStella Laurenzo 
1483436c6c9cSStella Laurenzo   // Merge operand/result segment lengths into attributes if needed.
1484436c6c9cSStella Laurenzo   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1485436c6c9cSStella Laurenzo     // Dup.
1486436c6c9cSStella Laurenzo     if (attributes) {
1487436c6c9cSStella Laurenzo       attributes = py::dict(*attributes);
1488436c6c9cSStella Laurenzo     } else {
1489436c6c9cSStella Laurenzo       attributes = py::dict();
1490436c6c9cSStella Laurenzo     }
1491436c6c9cSStella Laurenzo     if (attributes->contains("result_segment_sizes") ||
1492436c6c9cSStella Laurenzo         attributes->contains("operand_segment_sizes")) {
1493436c6c9cSStella Laurenzo       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1494436c6c9cSStella Laurenzo                             "'operand_segment_sizes' attribute is unsupported. "
1495436c6c9cSStella Laurenzo                             "Use Operation.create for such low-level access.");
1496436c6c9cSStella Laurenzo     }
1497436c6c9cSStella Laurenzo 
1498436c6c9cSStella Laurenzo     // Add result_segment_sizes attribute.
1499436c6c9cSStella Laurenzo     if (!resultSegmentLengths.empty()) {
1500436c6c9cSStella Laurenzo       int64_t size = resultSegmentLengths.size();
15018d05a288SStella Laurenzo       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
15028d05a288SStella Laurenzo           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1503436c6c9cSStella Laurenzo           resultSegmentLengths.size(), resultSegmentLengths.data());
1504436c6c9cSStella Laurenzo       (*attributes)["result_segment_sizes"] =
1505436c6c9cSStella Laurenzo           PyAttribute(context, segmentLengthAttr);
1506436c6c9cSStella Laurenzo     }
1507436c6c9cSStella Laurenzo 
1508436c6c9cSStella Laurenzo     // Add operand_segment_sizes attribute.
1509436c6c9cSStella Laurenzo     if (!operandSegmentLengths.empty()) {
1510436c6c9cSStella Laurenzo       int64_t size = operandSegmentLengths.size();
15118d05a288SStella Laurenzo       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
15128d05a288SStella Laurenzo           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1513436c6c9cSStella Laurenzo           operandSegmentLengths.size(), operandSegmentLengths.data());
1514436c6c9cSStella Laurenzo       (*attributes)["operand_segment_sizes"] =
1515436c6c9cSStella Laurenzo           PyAttribute(context, segmentLengthAttr);
1516436c6c9cSStella Laurenzo     }
1517436c6c9cSStella Laurenzo   }
1518436c6c9cSStella Laurenzo 
1519436c6c9cSStella Laurenzo   // Delegate to create.
1520337c937dSMehdi Amini   return PyOperation::create(name,
1521436c6c9cSStella Laurenzo                              /*results=*/std::move(resultTypes),
1522436c6c9cSStella Laurenzo                              /*operands=*/std::move(operands),
1523436c6c9cSStella Laurenzo                              /*attributes=*/std::move(attributes),
1524436c6c9cSStella Laurenzo                              /*successors=*/std::move(successors),
1525337c937dSMehdi Amini                              /*regions=*/*regions, location, maybeIp);
1526436c6c9cSStella Laurenzo }
1527436c6c9cSStella Laurenzo 
PyOpView(const py::object & operationObject)15281fc096afSMehdi Amini PyOpView::PyOpView(const py::object &operationObject)
1529436c6c9cSStella Laurenzo     // Casting through the PyOperationBase base-class and then back to the
1530436c6c9cSStella Laurenzo     // Operation lets us accept any PyOperationBase subclass.
1531436c6c9cSStella Laurenzo     : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1532436c6c9cSStella Laurenzo       operationObject(operation.getRef().getObject()) {}
1533436c6c9cSStella Laurenzo 
/// Synthesizes a Python class deriving from `userClass` whose `__init__` is
/// the base OpView `__init__`, so that instances can be built directly from an
/// existing operation.
///
/// User-facing OpView subclasses normally define a constructor that builds a
/// brand new operation from op-specific arguments, e.g.:
///   class AddFOp(_cext.ir.OpView):
///     def __init__(self, loc, lhs, rhs):
///       operation = loc.context.create_operation(
///           "addf", lhs, rhs, results=[lhs.type])
///       super().__init__(operation)
///
/// Such a constructor cannot wrap an operation that already exists. Instead of
/// munging *args/**kwargs, a new leaf class named "_" + the user class name is
/// created and given OpView's own `__init__`, bypassing the user's `__init__`.
/// This is safe: the hierarchy only gains a leaf and `__new__` is untouched.
/// The result is typically stored on the user class as "_Raw" and used for
/// casts and anything else that needs a class initialized purely from an
/// operation.
py::object PyOpView::createRawSubclass(const py::object &userClass) {
  // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
  // now.
  //   auto opViewType = py::type::of<PyOpView>();
  auto baseOpViewType = py::detail::get_type_handle(typeid(PyOpView), true);

  // Class namespace for the synthesized type: only `__init__` is overridden.
  py::dict classNamespace;
  classNamespace["__init__"] = baseOpViewType.attr("__init__");

  py::str userClassName = userClass.attr("__name__");
  py::str rawClassName = py::str("_") + userClassName;

  // Equivalent to calling `type(name, bases, namespace)` from Python.
  py::object typeMetaclass = py::reinterpret_borrow<py::object>(
      reinterpret_cast<PyObject *>(&PyType_Type));
  return typeMetaclass(rawClassName, py::make_tuple(userClass),
                       classNamespace);
}
1568436c6c9cSStella Laurenzo 
1569436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1570436c6c9cSStella Laurenzo // PyInsertionPoint.
1571436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1572436c6c9cSStella Laurenzo 
PyInsertionPoint(PyBlock & block)1573436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1574436c6c9cSStella Laurenzo 
PyInsertionPoint(PyOperationBase & beforeOperationBase)1575436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1576436c6c9cSStella Laurenzo     : refOperation(beforeOperationBase.getOperation().getRef()),
1577436c6c9cSStella Laurenzo       block((*refOperation)->getBlock()) {}
1578436c6c9cSStella Laurenzo 
insert(PyOperationBase & operationBase)1579436c6c9cSStella Laurenzo void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1580436c6c9cSStella Laurenzo   PyOperation &operation = operationBase.getOperation();
1581436c6c9cSStella Laurenzo   if (operation.isAttached())
1582436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError,
1583436c6c9cSStella Laurenzo                      "Attempt to insert operation that is already attached");
1584436c6c9cSStella Laurenzo   block.getParentOperation()->checkValid();
1585436c6c9cSStella Laurenzo   MlirOperation beforeOp = {nullptr};
1586436c6c9cSStella Laurenzo   if (refOperation) {
1587436c6c9cSStella Laurenzo     // Insert before operation.
1588436c6c9cSStella Laurenzo     (*refOperation)->checkValid();
1589436c6c9cSStella Laurenzo     beforeOp = (*refOperation)->get();
1590436c6c9cSStella Laurenzo   } else {
1591436c6c9cSStella Laurenzo     // Insert at end (before null) is only valid if the block does not
1592436c6c9cSStella Laurenzo     // already end in a known terminator (violating this will cause assertion
1593436c6c9cSStella Laurenzo     // failures later).
1594436c6c9cSStella Laurenzo     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1595436c6c9cSStella Laurenzo       throw py::index_error("Cannot insert operation at the end of a block "
1596436c6c9cSStella Laurenzo                             "that already has a terminator. Did you mean to "
1597436c6c9cSStella Laurenzo                             "use 'InsertionPoint.at_block_terminator(block)' "
1598436c6c9cSStella Laurenzo                             "versus 'InsertionPoint(block)'?");
1599436c6c9cSStella Laurenzo     }
1600436c6c9cSStella Laurenzo   }
1601436c6c9cSStella Laurenzo   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1602436c6c9cSStella Laurenzo   operation.setAttached();
1603436c6c9cSStella Laurenzo }
1604436c6c9cSStella Laurenzo 
atBlockBegin(PyBlock & block)1605436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1606436c6c9cSStella Laurenzo   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1607436c6c9cSStella Laurenzo   if (mlirOperationIsNull(firstOp)) {
1608436c6c9cSStella Laurenzo     // Just insert at end.
1609436c6c9cSStella Laurenzo     return PyInsertionPoint(block);
1610436c6c9cSStella Laurenzo   }
1611436c6c9cSStella Laurenzo 
1612436c6c9cSStella Laurenzo   // Insert before first op.
1613436c6c9cSStella Laurenzo   PyOperationRef firstOpRef = PyOperation::forOperation(
1614436c6c9cSStella Laurenzo       block.getParentOperation()->getContext(), firstOp);
1615436c6c9cSStella Laurenzo   return PyInsertionPoint{block, std::move(firstOpRef)};
1616436c6c9cSStella Laurenzo }
1617436c6c9cSStella Laurenzo 
atBlockTerminator(PyBlock & block)1618436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1619436c6c9cSStella Laurenzo   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1620436c6c9cSStella Laurenzo   if (mlirOperationIsNull(terminator))
1621436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1622436c6c9cSStella Laurenzo   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1623436c6c9cSStella Laurenzo       block.getParentOperation()->getContext(), terminator);
1624436c6c9cSStella Laurenzo   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1625436c6c9cSStella Laurenzo }
1626436c6c9cSStella Laurenzo 
contextEnter()1627436c6c9cSStella Laurenzo py::object PyInsertionPoint::contextEnter() {
1628436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushInsertionPoint(*this);
1629436c6c9cSStella Laurenzo }
1630436c6c9cSStella Laurenzo 
contextExit(const pybind11::object & excType,const pybind11::object & excVal,const pybind11::object & excTb)16311fc096afSMehdi Amini void PyInsertionPoint::contextExit(const pybind11::object &excType,
16321fc096afSMehdi Amini                                    const pybind11::object &excVal,
16331fc096afSMehdi Amini                                    const pybind11::object &excTb) {
1634436c6c9cSStella Laurenzo   PyThreadContextEntry::popInsertionPoint(*this);
1635436c6c9cSStella Laurenzo }
1636436c6c9cSStella Laurenzo 
1637436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1638436c6c9cSStella Laurenzo // PyAttribute.
1639436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1640436c6c9cSStella Laurenzo 
operator ==(const PyAttribute & other)1641436c6c9cSStella Laurenzo bool PyAttribute::operator==(const PyAttribute &other) {
1642436c6c9cSStella Laurenzo   return mlirAttributeEqual(attr, other.attr);
1643436c6c9cSStella Laurenzo }
1644436c6c9cSStella Laurenzo 
getCapsule()1645436c6c9cSStella Laurenzo py::object PyAttribute::getCapsule() {
1646436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1647436c6c9cSStella Laurenzo }
1648436c6c9cSStella Laurenzo 
createFromCapsule(py::object capsule)1649436c6c9cSStella Laurenzo PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1650436c6c9cSStella Laurenzo   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1651436c6c9cSStella Laurenzo   if (mlirAttributeIsNull(rawAttr))
1652436c6c9cSStella Laurenzo     throw py::error_already_set();
1653436c6c9cSStella Laurenzo   return PyAttribute(
1654436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1655436c6c9cSStella Laurenzo }
1656436c6c9cSStella Laurenzo 
1657436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1658436c6c9cSStella Laurenzo // PyNamedAttribute.
1659436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1660436c6c9cSStella Laurenzo 
PyNamedAttribute(MlirAttribute attr,std::string ownedName)1661436c6c9cSStella Laurenzo PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1662436c6c9cSStella Laurenzo     : ownedName(new std::string(std::move(ownedName))) {
1663436c6c9cSStella Laurenzo   namedAttr = mlirNamedAttributeGet(
1664436c6c9cSStella Laurenzo       mlirIdentifierGet(mlirAttributeGetContext(attr),
1665436c6c9cSStella Laurenzo                         toMlirStringRef(*this->ownedName)),
1666436c6c9cSStella Laurenzo       attr);
1667436c6c9cSStella Laurenzo }
1668436c6c9cSStella Laurenzo 
1669436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1670436c6c9cSStella Laurenzo // PyType.
1671436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1672436c6c9cSStella Laurenzo 
operator ==(const PyType & other)1673436c6c9cSStella Laurenzo bool PyType::operator==(const PyType &other) {
1674436c6c9cSStella Laurenzo   return mlirTypeEqual(type, other.type);
1675436c6c9cSStella Laurenzo }
1676436c6c9cSStella Laurenzo 
getCapsule()1677436c6c9cSStella Laurenzo py::object PyType::getCapsule() {
1678436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1679436c6c9cSStella Laurenzo }
1680436c6c9cSStella Laurenzo 
createFromCapsule(py::object capsule)1681436c6c9cSStella Laurenzo PyType PyType::createFromCapsule(py::object capsule) {
1682436c6c9cSStella Laurenzo   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1683436c6c9cSStella Laurenzo   if (mlirTypeIsNull(rawType))
1684436c6c9cSStella Laurenzo     throw py::error_already_set();
1685436c6c9cSStella Laurenzo   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1686436c6c9cSStella Laurenzo                 rawType);
1687436c6c9cSStella Laurenzo }
1688436c6c9cSStella Laurenzo 
1689436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
// PyValue and subclasses.
1691436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1692436c6c9cSStella Laurenzo 
getCapsule()16933f3d1c90SMike Urbach pybind11::object PyValue::getCapsule() {
16943f3d1c90SMike Urbach   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
16953f3d1c90SMike Urbach }
16963f3d1c90SMike Urbach 
createFromCapsule(pybind11::object capsule)16973f3d1c90SMike Urbach PyValue PyValue::createFromCapsule(pybind11::object capsule) {
16983f3d1c90SMike Urbach   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
16993f3d1c90SMike Urbach   if (mlirValueIsNull(value))
17003f3d1c90SMike Urbach     throw py::error_already_set();
17013f3d1c90SMike Urbach   MlirOperation owner;
17023f3d1c90SMike Urbach   if (mlirValueIsAOpResult(value))
17033f3d1c90SMike Urbach     owner = mlirOpResultGetOwner(value);
17043f3d1c90SMike Urbach   if (mlirValueIsABlockArgument(value))
17053f3d1c90SMike Urbach     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
17063f3d1c90SMike Urbach   if (mlirOperationIsNull(owner))
17073f3d1c90SMike Urbach     throw py::error_already_set();
17083f3d1c90SMike Urbach   MlirContext ctx = mlirOperationGetContext(owner);
17093f3d1c90SMike Urbach   PyOperationRef ownerRef =
17103f3d1c90SMike Urbach       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
17113f3d1c90SMike Urbach   return PyValue(ownerRef, value);
17123f3d1c90SMike Urbach }
17133f3d1c90SMike Urbach 
171430d61893SAlex Zinenko //------------------------------------------------------------------------------
171530d61893SAlex Zinenko // PySymbolTable.
171630d61893SAlex Zinenko //------------------------------------------------------------------------------
171730d61893SAlex Zinenko 
PySymbolTable(PyOperationBase & operation)171830d61893SAlex Zinenko PySymbolTable::PySymbolTable(PyOperationBase &operation)
171930d61893SAlex Zinenko     : operation(operation.getOperation().getRef()) {
172030d61893SAlex Zinenko   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
172130d61893SAlex Zinenko   if (mlirSymbolTableIsNull(symbolTable)) {
172230d61893SAlex Zinenko     throw py::cast_error("Operation is not a Symbol Table.");
172330d61893SAlex Zinenko   }
172430d61893SAlex Zinenko }
172530d61893SAlex Zinenko 
dunderGetItem(const std::string & name)172630d61893SAlex Zinenko py::object PySymbolTable::dunderGetItem(const std::string &name) {
172730d61893SAlex Zinenko   operation->checkValid();
172830d61893SAlex Zinenko   MlirOperation symbol = mlirSymbolTableLookup(
172930d61893SAlex Zinenko       symbolTable, mlirStringRefCreate(name.data(), name.length()));
173030d61893SAlex Zinenko   if (mlirOperationIsNull(symbol))
173130d61893SAlex Zinenko     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
173230d61893SAlex Zinenko 
173330d61893SAlex Zinenko   return PyOperation::forOperation(operation->getContext(), symbol,
173430d61893SAlex Zinenko                                    operation.getObject())
173530d61893SAlex Zinenko       ->createOpView();
173630d61893SAlex Zinenko }
173730d61893SAlex Zinenko 
erase(PyOperationBase & symbol)173830d61893SAlex Zinenko void PySymbolTable::erase(PyOperationBase &symbol) {
173930d61893SAlex Zinenko   operation->checkValid();
174030d61893SAlex Zinenko   symbol.getOperation().checkValid();
174130d61893SAlex Zinenko   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
174230d61893SAlex Zinenko   // The operation is also erased, so we must invalidate it. There may be Python
174330d61893SAlex Zinenko   // references to this operation so we don't want to delete it from the list of
174430d61893SAlex Zinenko   // live operations here.
174530d61893SAlex Zinenko   symbol.getOperation().valid = false;
174630d61893SAlex Zinenko }
174730d61893SAlex Zinenko 
dunderDel(const std::string & name)174830d61893SAlex Zinenko void PySymbolTable::dunderDel(const std::string &name) {
174930d61893SAlex Zinenko   py::object operation = dunderGetItem(name);
175030d61893SAlex Zinenko   erase(py::cast<PyOperationBase &>(operation));
175130d61893SAlex Zinenko }
175230d61893SAlex Zinenko 
insert(PyOperationBase & symbol)175330d61893SAlex Zinenko PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
175430d61893SAlex Zinenko   operation->checkValid();
175530d61893SAlex Zinenko   symbol.getOperation().checkValid();
175630d61893SAlex Zinenko   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
175730d61893SAlex Zinenko       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
175830d61893SAlex Zinenko   if (mlirAttributeIsNull(symbolAttr))
175930d61893SAlex Zinenko     throw py::value_error("Expected operation to have a symbol name.");
176030d61893SAlex Zinenko   return PyAttribute(
176130d61893SAlex Zinenko       symbol.getOperation().getContext(),
176230d61893SAlex Zinenko       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
176330d61893SAlex Zinenko }
176430d61893SAlex Zinenko 
getSymbolName(PyOperationBase & symbol)1765bdc31837SStella Laurenzo PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1766bdc31837SStella Laurenzo   // Op must already be a symbol.
1767bdc31837SStella Laurenzo   PyOperation &operation = symbol.getOperation();
1768bdc31837SStella Laurenzo   operation.checkValid();
1769bdc31837SStella Laurenzo   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1770bdc31837SStella Laurenzo   MlirAttribute existingNameAttr =
1771bdc31837SStella Laurenzo       mlirOperationGetAttributeByName(operation.get(), attrName);
1772bdc31837SStella Laurenzo   if (mlirAttributeIsNull(existingNameAttr))
1773bdc31837SStella Laurenzo     throw py::value_error("Expected operation to have a symbol name.");
1774bdc31837SStella Laurenzo   return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1775bdc31837SStella Laurenzo }
1776bdc31837SStella Laurenzo 
setSymbolName(PyOperationBase & symbol,const std::string & name)1777bdc31837SStella Laurenzo void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1778bdc31837SStella Laurenzo                                   const std::string &name) {
1779bdc31837SStella Laurenzo   // Op must already be a symbol.
1780bdc31837SStella Laurenzo   PyOperation &operation = symbol.getOperation();
1781bdc31837SStella Laurenzo   operation.checkValid();
1782bdc31837SStella Laurenzo   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1783bdc31837SStella Laurenzo   MlirAttribute existingNameAttr =
1784bdc31837SStella Laurenzo       mlirOperationGetAttributeByName(operation.get(), attrName);
1785bdc31837SStella Laurenzo   if (mlirAttributeIsNull(existingNameAttr))
1786bdc31837SStella Laurenzo     throw py::value_error("Expected operation to have a symbol name.");
1787bdc31837SStella Laurenzo   MlirAttribute newNameAttr =
1788bdc31837SStella Laurenzo       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1789bdc31837SStella Laurenzo   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1790bdc31837SStella Laurenzo }
1791bdc31837SStella Laurenzo 
getVisibility(PyOperationBase & symbol)1792bdc31837SStella Laurenzo PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1793bdc31837SStella Laurenzo   PyOperation &operation = symbol.getOperation();
1794bdc31837SStella Laurenzo   operation.checkValid();
1795bdc31837SStella Laurenzo   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1796bdc31837SStella Laurenzo   MlirAttribute existingVisAttr =
1797bdc31837SStella Laurenzo       mlirOperationGetAttributeByName(operation.get(), attrName);
1798bdc31837SStella Laurenzo   if (mlirAttributeIsNull(existingVisAttr))
1799bdc31837SStella Laurenzo     throw py::value_error("Expected operation to have a symbol visibility.");
1800bdc31837SStella Laurenzo   return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1801bdc31837SStella Laurenzo }
1802bdc31837SStella Laurenzo 
setVisibility(PyOperationBase & symbol,const std::string & visibility)1803bdc31837SStella Laurenzo void PySymbolTable::setVisibility(PyOperationBase &symbol,
1804bdc31837SStella Laurenzo                                   const std::string &visibility) {
1805bdc31837SStella Laurenzo   if (visibility != "public" && visibility != "private" &&
1806bdc31837SStella Laurenzo       visibility != "nested")
1807bdc31837SStella Laurenzo     throw py::value_error(
1808bdc31837SStella Laurenzo         "Expected visibility to be 'public', 'private' or 'nested'");
1809bdc31837SStella Laurenzo   PyOperation &operation = symbol.getOperation();
1810bdc31837SStella Laurenzo   operation.checkValid();
1811bdc31837SStella Laurenzo   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1812bdc31837SStella Laurenzo   MlirAttribute existingVisAttr =
1813bdc31837SStella Laurenzo       mlirOperationGetAttributeByName(operation.get(), attrName);
1814bdc31837SStella Laurenzo   if (mlirAttributeIsNull(existingVisAttr))
1815bdc31837SStella Laurenzo     throw py::value_error("Expected operation to have a symbol visibility.");
1816bdc31837SStella Laurenzo   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1817bdc31837SStella Laurenzo                                                toMlirStringRef(visibility));
1818bdc31837SStella Laurenzo   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1819bdc31837SStella Laurenzo }
1820bdc31837SStella Laurenzo 
replaceAllSymbolUses(const std::string & oldSymbol,const std::string & newSymbol,PyOperationBase & from)1821bdc31837SStella Laurenzo void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1822bdc31837SStella Laurenzo                                          const std::string &newSymbol,
1823bdc31837SStella Laurenzo                                          PyOperationBase &from) {
1824bdc31837SStella Laurenzo   PyOperation &fromOperation = from.getOperation();
1825bdc31837SStella Laurenzo   fromOperation.checkValid();
1826bdc31837SStella Laurenzo   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1827bdc31837SStella Laurenzo           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1828bdc31837SStella Laurenzo           from.getOperation())))
1829bdc31837SStella Laurenzo 
1830bdc31837SStella Laurenzo     throw py::value_error("Symbol rename failed");
1831bdc31837SStella Laurenzo }
1832bdc31837SStella Laurenzo 
walkSymbolTables(PyOperationBase & from,bool allSymUsesVisible,py::object callback)1833bdc31837SStella Laurenzo void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1834bdc31837SStella Laurenzo                                      bool allSymUsesVisible,
1835bdc31837SStella Laurenzo                                      py::object callback) {
1836bdc31837SStella Laurenzo   PyOperation &fromOperation = from.getOperation();
1837bdc31837SStella Laurenzo   fromOperation.checkValid();
1838bdc31837SStella Laurenzo   struct UserData {
1839bdc31837SStella Laurenzo     PyMlirContextRef context;
1840bdc31837SStella Laurenzo     py::object callback;
1841bdc31837SStella Laurenzo     bool gotException;
1842bdc31837SStella Laurenzo     std::string exceptionWhat;
1843bdc31837SStella Laurenzo     py::object exceptionType;
1844bdc31837SStella Laurenzo   };
1845bdc31837SStella Laurenzo   UserData userData{
1846bdc31837SStella Laurenzo       fromOperation.getContext(), std::move(callback), false, {}, {}};
1847bdc31837SStella Laurenzo   mlirSymbolTableWalkSymbolTables(
1848bdc31837SStella Laurenzo       fromOperation.get(), allSymUsesVisible,
1849bdc31837SStella Laurenzo       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1850bdc31837SStella Laurenzo         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1851bdc31837SStella Laurenzo         auto pyFoundOp =
1852bdc31837SStella Laurenzo             PyOperation::forOperation(calleeUserData->context, foundOp);
1853bdc31837SStella Laurenzo         if (calleeUserData->gotException)
1854bdc31837SStella Laurenzo           return;
1855bdc31837SStella Laurenzo         try {
1856bdc31837SStella Laurenzo           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1857bdc31837SStella Laurenzo         } catch (py::error_already_set &e) {
1858bdc31837SStella Laurenzo           calleeUserData->gotException = true;
1859bdc31837SStella Laurenzo           calleeUserData->exceptionWhat = e.what();
1860bdc31837SStella Laurenzo           calleeUserData->exceptionType = e.type();
1861bdc31837SStella Laurenzo         }
1862bdc31837SStella Laurenzo       },
1863bdc31837SStella Laurenzo       static_cast<void *>(&userData));
1864bdc31837SStella Laurenzo   if (userData.gotException) {
1865bdc31837SStella Laurenzo     std::string message("Exception raised in callback: ");
1866bdc31837SStella Laurenzo     message.append(userData.exceptionWhat);
1867337c937dSMehdi Amini     throw std::runtime_error(message);
1868bdc31837SStella Laurenzo   }
1869bdc31837SStella Laurenzo }
1870bdc31837SStella Laurenzo 
1871436c6c9cSStella Laurenzo namespace {
1872436c6c9cSStella Laurenzo /// CRTP base class for Python MLIR values that subclass Value and should be
1873436c6c9cSStella Laurenzo /// castable from it. The value hierarchy is one level deep and is not supposed
1874436c6c9cSStella Laurenzo /// to accommodate other levels unless core MLIR changes.
1875436c6c9cSStella Laurenzo template <typename DerivedTy>
1876436c6c9cSStella Laurenzo class PyConcreteValue : public PyValue {
1877436c6c9cSStella Laurenzo public:
1878436c6c9cSStella Laurenzo   // Derived classes must define statics for:
1879436c6c9cSStella Laurenzo   //   IsAFunctionTy isaFunction
1880436c6c9cSStella Laurenzo   //   const char *pyClassName
1881436c6c9cSStella Laurenzo   // and redefine bindDerived.
1882436c6c9cSStella Laurenzo   using ClassTy = py::class_<DerivedTy, PyValue>;
1883436c6c9cSStella Laurenzo   using IsAFunctionTy = bool (*)(MlirValue);
1884436c6c9cSStella Laurenzo 
1885436c6c9cSStella Laurenzo   PyConcreteValue() = default;
PyConcreteValue(PyOperationRef operationRef,MlirValue value)1886436c6c9cSStella Laurenzo   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1887436c6c9cSStella Laurenzo       : PyValue(operationRef, value) {}
PyConcreteValue(PyValue & orig)1888436c6c9cSStella Laurenzo   PyConcreteValue(PyValue &orig)
1889436c6c9cSStella Laurenzo       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1890436c6c9cSStella Laurenzo 
1891436c6c9cSStella Laurenzo   /// Attempts to cast the original value to the derived type and throws on
1892436c6c9cSStella Laurenzo   /// type mismatches.
castFrom(PyValue & orig)1893436c6c9cSStella Laurenzo   static MlirValue castFrom(PyValue &orig) {
1894436c6c9cSStella Laurenzo     if (!DerivedTy::isaFunction(orig.get())) {
1895436c6c9cSStella Laurenzo       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1896436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1897436c6c9cSStella Laurenzo                                              DerivedTy::pyClassName +
1898436c6c9cSStella Laurenzo                                              " (from " + origRepr + ")");
1899436c6c9cSStella Laurenzo     }
1900436c6c9cSStella Laurenzo     return orig.get();
1901436c6c9cSStella Laurenzo   }
1902436c6c9cSStella Laurenzo 
1903436c6c9cSStella Laurenzo   /// Binds the Python module objects to functions of this class.
bind(py::module & m)1904436c6c9cSStella Laurenzo   static void bind(py::module &m) {
1905f05ff4f7SStella Laurenzo     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1906a6e7d024SStella Laurenzo     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1907a6e7d024SStella Laurenzo     cls.def_static(
1908a6e7d024SStella Laurenzo         "isinstance",
1909a6e7d024SStella Laurenzo         [](PyValue &otherValue) -> bool {
191078f2dae0SAlex Zinenko           return DerivedTy::isaFunction(otherValue);
1911a6e7d024SStella Laurenzo         },
1912a6e7d024SStella Laurenzo         py::arg("other_value"));
1913436c6c9cSStella Laurenzo     DerivedTy::bindDerived(cls);
1914436c6c9cSStella Laurenzo   }
1915436c6c9cSStella Laurenzo 
1916436c6c9cSStella Laurenzo   /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)1917436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &m) {}
1918436c6c9cSStella Laurenzo };
1919436c6c9cSStella Laurenzo 
1920436c6c9cSStella Laurenzo /// Python wrapper for MlirBlockArgument.
1921436c6c9cSStella Laurenzo class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1922436c6c9cSStella Laurenzo public:
1923436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1924436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BlockArgument";
1925436c6c9cSStella Laurenzo   using PyConcreteValue::PyConcreteValue;
1926436c6c9cSStella Laurenzo 
bindDerived(ClassTy & c)1927436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1928436c6c9cSStella Laurenzo     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1929436c6c9cSStella Laurenzo       return PyBlock(self.getParentOperation(),
1930436c6c9cSStella Laurenzo                      mlirBlockArgumentGetOwner(self.get()));
1931436c6c9cSStella Laurenzo     });
1932436c6c9cSStella Laurenzo     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1933436c6c9cSStella Laurenzo       return mlirBlockArgumentGetArgNumber(self.get());
1934436c6c9cSStella Laurenzo     });
1935a6e7d024SStella Laurenzo     c.def(
1936a6e7d024SStella Laurenzo         "set_type",
1937a6e7d024SStella Laurenzo         [](PyBlockArgument &self, PyType type) {
1938436c6c9cSStella Laurenzo           return mlirBlockArgumentSetType(self.get(), type);
1939a6e7d024SStella Laurenzo         },
1940a6e7d024SStella Laurenzo         py::arg("type"));
1941436c6c9cSStella Laurenzo   }
1942436c6c9cSStella Laurenzo };
1943436c6c9cSStella Laurenzo 
1944436c6c9cSStella Laurenzo /// Python wrapper for MlirOpResult.
1945436c6c9cSStella Laurenzo class PyOpResult : public PyConcreteValue<PyOpResult> {
1946436c6c9cSStella Laurenzo public:
1947436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1948436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpResult";
1949436c6c9cSStella Laurenzo   using PyConcreteValue::PyConcreteValue;
1950436c6c9cSStella Laurenzo 
bindDerived(ClassTy & c)1951436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1952436c6c9cSStella Laurenzo     c.def_property_readonly("owner", [](PyOpResult &self) {
1953436c6c9cSStella Laurenzo       assert(
1954436c6c9cSStella Laurenzo           mlirOperationEqual(self.getParentOperation()->get(),
1955436c6c9cSStella Laurenzo                              mlirOpResultGetOwner(self.get())) &&
1956436c6c9cSStella Laurenzo           "expected the owner of the value in Python to match that in the IR");
19576ff74f96SMike Urbach       return self.getParentOperation().getObject();
1958436c6c9cSStella Laurenzo     });
1959436c6c9cSStella Laurenzo     c.def_property_readonly("result_number", [](PyOpResult &self) {
1960436c6c9cSStella Laurenzo       return mlirOpResultGetResultNumber(self.get());
1961436c6c9cSStella Laurenzo     });
1962436c6c9cSStella Laurenzo   }
1963436c6c9cSStella Laurenzo };
1964436c6c9cSStella Laurenzo 
1965ed9e52f3SAlex Zinenko /// Returns the list of types of the values held by container.
1966ed9e52f3SAlex Zinenko template <typename Container>
getValueTypes(Container & container,PyMlirContextRef & context)1967ed9e52f3SAlex Zinenko static std::vector<PyType> getValueTypes(Container &container,
1968ed9e52f3SAlex Zinenko                                          PyMlirContextRef &context) {
1969ed9e52f3SAlex Zinenko   std::vector<PyType> result;
1970ee168fb9SAlex Zinenko   result.reserve(container.size());
1971ee168fb9SAlex Zinenko   for (int i = 0, e = container.size(); i < e; ++i) {
1972ed9e52f3SAlex Zinenko     result.push_back(
1973ed9e52f3SAlex Zinenko         PyType(context, mlirValueGetType(container.getElement(i).get())));
1974ed9e52f3SAlex Zinenko   }
1975ed9e52f3SAlex Zinenko   return result;
1976ed9e52f3SAlex Zinenko }
1977ed9e52f3SAlex Zinenko 
1978436c6c9cSStella Laurenzo /// A list of block arguments. Internally, these are stored as consecutive
1979436c6c9cSStella Laurenzo /// elements, random access is cheap. The argument list is associated with the
1980436c6c9cSStella Laurenzo /// operation that contains the block (detached blocks are not allowed in
1981436c6c9cSStella Laurenzo /// Python bindings) and extends its lifetime.
1982afeda4b9SAlex Zinenko class PyBlockArgumentList
1983afeda4b9SAlex Zinenko     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1984436c6c9cSStella Laurenzo public:
1985afeda4b9SAlex Zinenko   static constexpr const char *pyClassName = "BlockArgumentList";
1986436c6c9cSStella Laurenzo 
PyBlockArgumentList(PyOperationRef operation,MlirBlock block,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1987afeda4b9SAlex Zinenko   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1988afeda4b9SAlex Zinenko                       intptr_t startIndex = 0, intptr_t length = -1,
1989afeda4b9SAlex Zinenko                       intptr_t step = 1)
1990afeda4b9SAlex Zinenko       : Sliceable(startIndex,
1991afeda4b9SAlex Zinenko                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1992afeda4b9SAlex Zinenko                   step),
1993afeda4b9SAlex Zinenko         operation(std::move(operation)), block(block) {}
1994afeda4b9SAlex Zinenko 
bindDerived(ClassTy & c)1995ee168fb9SAlex Zinenko   static void bindDerived(ClassTy &c) {
1996ee168fb9SAlex Zinenko     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1997ee168fb9SAlex Zinenko       return getValueTypes(self, self.operation->getContext());
1998ee168fb9SAlex Zinenko     });
1999ee168fb9SAlex Zinenko   }
2000ee168fb9SAlex Zinenko 
2001ee168fb9SAlex Zinenko private:
2002ee168fb9SAlex Zinenko   /// Give the parent CRTP class access to hook implementations below.
2003ee168fb9SAlex Zinenko   friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2004ee168fb9SAlex Zinenko 
2005afeda4b9SAlex Zinenko   /// Returns the number of arguments in the list.
getRawNumElements()2006ee168fb9SAlex Zinenko   intptr_t getRawNumElements() {
2007436c6c9cSStella Laurenzo     operation->checkValid();
2008436c6c9cSStella Laurenzo     return mlirBlockGetNumArguments(block);
2009436c6c9cSStella Laurenzo   }
2010436c6c9cSStella Laurenzo 
2011ee168fb9SAlex Zinenko   /// Returns `pos`-the element in the list.
getRawElement(intptr_t pos)2012ee168fb9SAlex Zinenko   PyBlockArgument getRawElement(intptr_t pos) {
2013afeda4b9SAlex Zinenko     MlirValue argument = mlirBlockGetArgument(block, pos);
2014afeda4b9SAlex Zinenko     return PyBlockArgument(operation, argument);
2015436c6c9cSStella Laurenzo   }
2016436c6c9cSStella Laurenzo 
2017afeda4b9SAlex Zinenko   /// Returns a sublist of this list.
slice(intptr_t startIndex,intptr_t length,intptr_t step)2018afeda4b9SAlex Zinenko   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2019afeda4b9SAlex Zinenko                             intptr_t step) {
2020afeda4b9SAlex Zinenko     return PyBlockArgumentList(operation, block, startIndex, length, step);
2021436c6c9cSStella Laurenzo   }
2022436c6c9cSStella Laurenzo 
2023436c6c9cSStella Laurenzo   PyOperationRef operation;
2024436c6c9cSStella Laurenzo   MlirBlock block;
2025436c6c9cSStella Laurenzo };
2026436c6c9cSStella Laurenzo 
2027436c6c9cSStella Laurenzo /// A list of operation operands. Internally, these are stored as consecutive
2028436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the
2029436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this
2030436c6c9cSStella Laurenzo /// operation.
2031436c6c9cSStella Laurenzo class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2032436c6c9cSStella Laurenzo public:
2033436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpOperandList";
2034436c6c9cSStella Laurenzo 
PyOpOperandList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)2035436c6c9cSStella Laurenzo   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2036436c6c9cSStella Laurenzo                   intptr_t length = -1, intptr_t step = 1)
2037436c6c9cSStella Laurenzo       : Sliceable(startIndex,
2038436c6c9cSStella Laurenzo                   length == -1 ? mlirOperationGetNumOperands(operation->get())
2039436c6c9cSStella Laurenzo                                : length,
2040436c6c9cSStella Laurenzo                   step),
2041436c6c9cSStella Laurenzo         operation(operation) {}
2042436c6c9cSStella Laurenzo 
dunderSetItem(intptr_t index,PyValue value)2043ee168fb9SAlex Zinenko   void dunderSetItem(intptr_t index, PyValue value) {
2044ee168fb9SAlex Zinenko     index = wrapIndex(index);
2045ee168fb9SAlex Zinenko     mlirOperationSetOperand(operation->get(), index, value.get());
2046ee168fb9SAlex Zinenko   }
2047ee168fb9SAlex Zinenko 
bindDerived(ClassTy & c)2048ee168fb9SAlex Zinenko   static void bindDerived(ClassTy &c) {
2049ee168fb9SAlex Zinenko     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2050ee168fb9SAlex Zinenko   }
2051ee168fb9SAlex Zinenko 
2052ee168fb9SAlex Zinenko private:
2053ee168fb9SAlex Zinenko   /// Give the parent CRTP class access to hook implementations below.
2054ee168fb9SAlex Zinenko   friend class Sliceable<PyOpOperandList, PyValue>;
2055ee168fb9SAlex Zinenko 
getRawNumElements()2056ee168fb9SAlex Zinenko   intptr_t getRawNumElements() {
2057436c6c9cSStella Laurenzo     operation->checkValid();
2058436c6c9cSStella Laurenzo     return mlirOperationGetNumOperands(operation->get());
2059436c6c9cSStella Laurenzo   }
2060436c6c9cSStella Laurenzo 
getRawElement(intptr_t pos)2061ee168fb9SAlex Zinenko   PyValue getRawElement(intptr_t pos) {
20625664c5e2SJohn Demme     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
20635664c5e2SJohn Demme     MlirOperation owner;
20645664c5e2SJohn Demme     if (mlirValueIsAOpResult(operand))
20655664c5e2SJohn Demme       owner = mlirOpResultGetOwner(operand);
20665664c5e2SJohn Demme     else if (mlirValueIsABlockArgument(operand))
20675664c5e2SJohn Demme       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
20685664c5e2SJohn Demme     else
20695664c5e2SJohn Demme       assert(false && "Value must be an block arg or op result.");
20705664c5e2SJohn Demme     PyOperationRef pyOwner =
20715664c5e2SJohn Demme         PyOperation::forOperation(operation->getContext(), owner);
20725664c5e2SJohn Demme     return PyValue(pyOwner, operand);
2073436c6c9cSStella Laurenzo   }
2074436c6c9cSStella Laurenzo 
slice(intptr_t startIndex,intptr_t length,intptr_t step)2075436c6c9cSStella Laurenzo   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2076436c6c9cSStella Laurenzo     return PyOpOperandList(operation, startIndex, length, step);
2077436c6c9cSStella Laurenzo   }
2078436c6c9cSStella Laurenzo 
2079436c6c9cSStella Laurenzo   PyOperationRef operation;
2080436c6c9cSStella Laurenzo };
2081436c6c9cSStella Laurenzo 
2082436c6c9cSStella Laurenzo /// A list of operation results. Internally, these are stored as consecutive
2083436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the
2084436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this
2085436c6c9cSStella Laurenzo /// operation.
2086436c6c9cSStella Laurenzo class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2087436c6c9cSStella Laurenzo public:
2088436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpResultList";
2089436c6c9cSStella Laurenzo 
PyOpResultList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)2090436c6c9cSStella Laurenzo   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2091436c6c9cSStella Laurenzo                  intptr_t length = -1, intptr_t step = 1)
2092436c6c9cSStella Laurenzo       : Sliceable(startIndex,
2093436c6c9cSStella Laurenzo                   length == -1 ? mlirOperationGetNumResults(operation->get())
2094436c6c9cSStella Laurenzo                                : length,
2095436c6c9cSStella Laurenzo                   step),
2096436c6c9cSStella Laurenzo         operation(operation) {}
2097436c6c9cSStella Laurenzo 
bindDerived(ClassTy & c)2098ee168fb9SAlex Zinenko   static void bindDerived(ClassTy &c) {
2099ee168fb9SAlex Zinenko     c.def_property_readonly("types", [](PyOpResultList &self) {
2100ee168fb9SAlex Zinenko       return getValueTypes(self, self.operation->getContext());
2101ee168fb9SAlex Zinenko     });
2102ee168fb9SAlex Zinenko   }
2103ee168fb9SAlex Zinenko 
2104ee168fb9SAlex Zinenko private:
2105ee168fb9SAlex Zinenko   /// Give the parent CRTP class access to hook implementations below.
2106ee168fb9SAlex Zinenko   friend class Sliceable<PyOpResultList, PyOpResult>;
2107ee168fb9SAlex Zinenko 
getRawNumElements()2108ee168fb9SAlex Zinenko   intptr_t getRawNumElements() {
2109436c6c9cSStella Laurenzo     operation->checkValid();
2110436c6c9cSStella Laurenzo     return mlirOperationGetNumResults(operation->get());
2111436c6c9cSStella Laurenzo   }
2112436c6c9cSStella Laurenzo 
getRawElement(intptr_t index)2113ee168fb9SAlex Zinenko   PyOpResult getRawElement(intptr_t index) {
2114436c6c9cSStella Laurenzo     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2115436c6c9cSStella Laurenzo     return PyOpResult(value);
2116436c6c9cSStella Laurenzo   }
2117436c6c9cSStella Laurenzo 
slice(intptr_t startIndex,intptr_t length,intptr_t step)2118436c6c9cSStella Laurenzo   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2119436c6c9cSStella Laurenzo     return PyOpResultList(operation, startIndex, length, step);
2120436c6c9cSStella Laurenzo   }
2121436c6c9cSStella Laurenzo 
2122436c6c9cSStella Laurenzo   PyOperationRef operation;
2123436c6c9cSStella Laurenzo };
2124436c6c9cSStella Laurenzo 
2125436c6c9cSStella Laurenzo /// A list of operation attributes. Can be indexed by name, producing
2126436c6c9cSStella Laurenzo /// attributes, or by index, producing named attributes.
2127436c6c9cSStella Laurenzo class PyOpAttributeMap {
2128436c6c9cSStella Laurenzo public:
PyOpAttributeMap(PyOperationRef operation)21291fc096afSMehdi Amini   PyOpAttributeMap(PyOperationRef operation)
21301fc096afSMehdi Amini       : operation(std::move(operation)) {}
2131436c6c9cSStella Laurenzo 
dunderGetItemNamed(const std::string & name)2132436c6c9cSStella Laurenzo   PyAttribute dunderGetItemNamed(const std::string &name) {
2133436c6c9cSStella Laurenzo     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2134436c6c9cSStella Laurenzo                                                          toMlirStringRef(name));
2135436c6c9cSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
2136436c6c9cSStella Laurenzo       throw SetPyError(PyExc_KeyError,
2137436c6c9cSStella Laurenzo                        "attempt to access a non-existent attribute");
2138436c6c9cSStella Laurenzo     }
2139436c6c9cSStella Laurenzo     return PyAttribute(operation->getContext(), attr);
2140436c6c9cSStella Laurenzo   }
2141436c6c9cSStella Laurenzo 
dunderGetItemIndexed(intptr_t index)2142436c6c9cSStella Laurenzo   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2143436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
2144436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
2145436c6c9cSStella Laurenzo                        "attempt to access out of bounds attribute");
2146436c6c9cSStella Laurenzo     }
2147436c6c9cSStella Laurenzo     MlirNamedAttribute namedAttr =
2148436c6c9cSStella Laurenzo         mlirOperationGetAttribute(operation->get(), index);
2149436c6c9cSStella Laurenzo     return PyNamedAttribute(
2150436c6c9cSStella Laurenzo         namedAttr.attribute,
2151120591e1SRiver Riddle         std::string(mlirIdentifierStr(namedAttr.name).data,
2152120591e1SRiver Riddle                     mlirIdentifierStr(namedAttr.name).length));
2153436c6c9cSStella Laurenzo   }
2154436c6c9cSStella Laurenzo 
dunderSetItem(const std::string & name,const PyAttribute & attr)21551fc096afSMehdi Amini   void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2156436c6c9cSStella Laurenzo     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2157436c6c9cSStella Laurenzo                                     attr);
2158436c6c9cSStella Laurenzo   }
2159436c6c9cSStella Laurenzo 
dunderDelItem(const std::string & name)2160436c6c9cSStella Laurenzo   void dunderDelItem(const std::string &name) {
2161436c6c9cSStella Laurenzo     int removed = mlirOperationRemoveAttributeByName(operation->get(),
2162436c6c9cSStella Laurenzo                                                      toMlirStringRef(name));
2163436c6c9cSStella Laurenzo     if (!removed)
2164436c6c9cSStella Laurenzo       throw SetPyError(PyExc_KeyError,
2165436c6c9cSStella Laurenzo                        "attempt to delete a non-existent attribute");
2166436c6c9cSStella Laurenzo   }
2167436c6c9cSStella Laurenzo 
dunderLen()2168436c6c9cSStella Laurenzo   intptr_t dunderLen() {
2169436c6c9cSStella Laurenzo     return mlirOperationGetNumAttributes(operation->get());
2170436c6c9cSStella Laurenzo   }
2171436c6c9cSStella Laurenzo 
dunderContains(const std::string & name)2172436c6c9cSStella Laurenzo   bool dunderContains(const std::string &name) {
2173436c6c9cSStella Laurenzo     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2174436c6c9cSStella Laurenzo         operation->get(), toMlirStringRef(name)));
2175436c6c9cSStella Laurenzo   }
2176436c6c9cSStella Laurenzo 
bind(py::module & m)2177436c6c9cSStella Laurenzo   static void bind(py::module &m) {
2178f05ff4f7SStella Laurenzo     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2179436c6c9cSStella Laurenzo         .def("__contains__", &PyOpAttributeMap::dunderContains)
2180436c6c9cSStella Laurenzo         .def("__len__", &PyOpAttributeMap::dunderLen)
2181436c6c9cSStella Laurenzo         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2182436c6c9cSStella Laurenzo         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2183436c6c9cSStella Laurenzo         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2184436c6c9cSStella Laurenzo         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2185436c6c9cSStella Laurenzo   }
2186436c6c9cSStella Laurenzo 
2187436c6c9cSStella Laurenzo private:
2188436c6c9cSStella Laurenzo   PyOperationRef operation;
2189436c6c9cSStella Laurenzo };
2190436c6c9cSStella Laurenzo 
2191be0a7e9fSMehdi Amini } // namespace
2192436c6c9cSStella Laurenzo 
2193436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2194436c6c9cSStella Laurenzo // Populates the core exports of the 'ir' submodule.
2195436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
2196436c6c9cSStella Laurenzo 
populateIRCore(py::module & m)2197436c6c9cSStella Laurenzo void mlir::python::populateIRCore(py::module &m) {
2198436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
21997ee25bc5SStella Laurenzo   // Enums.
22007ee25bc5SStella Laurenzo   //----------------------------------------------------------------------------
22017ee25bc5SStella Laurenzo   py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
22027ee25bc5SStella Laurenzo       .value("ERROR", MlirDiagnosticError)
22037ee25bc5SStella Laurenzo       .value("WARNING", MlirDiagnosticWarning)
22047ee25bc5SStella Laurenzo       .value("NOTE", MlirDiagnosticNote)
22057ee25bc5SStella Laurenzo       .value("REMARK", MlirDiagnosticRemark);
22067ee25bc5SStella Laurenzo 
22077ee25bc5SStella Laurenzo   //----------------------------------------------------------------------------
22087ee25bc5SStella Laurenzo   // Mapping of Diagnostics.
22097ee25bc5SStella Laurenzo   //----------------------------------------------------------------------------
22107ee25bc5SStella Laurenzo   py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
22117ee25bc5SStella Laurenzo       .def_property_readonly("severity", &PyDiagnostic::getSeverity)
22127ee25bc5SStella Laurenzo       .def_property_readonly("location", &PyDiagnostic::getLocation)
22137ee25bc5SStella Laurenzo       .def_property_readonly("message", &PyDiagnostic::getMessage)
22147ee25bc5SStella Laurenzo       .def_property_readonly("notes", &PyDiagnostic::getNotes)
22157ee25bc5SStella Laurenzo       .def("__str__", [](PyDiagnostic &self) -> py::str {
22167ee25bc5SStella Laurenzo         if (!self.isValid())
22177ee25bc5SStella Laurenzo           return "<Invalid Diagnostic>";
22187ee25bc5SStella Laurenzo         return self.getMessage();
22197ee25bc5SStella Laurenzo       });
22207ee25bc5SStella Laurenzo 
22217ee25bc5SStella Laurenzo   py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
22227ee25bc5SStella Laurenzo       .def("detach", &PyDiagnosticHandler::detach)
22237ee25bc5SStella Laurenzo       .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
22247ee25bc5SStella Laurenzo       .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
22257ee25bc5SStella Laurenzo       .def("__enter__", &PyDiagnosticHandler::contextEnter)
22267ee25bc5SStella Laurenzo       .def("__exit__", &PyDiagnosticHandler::contextExit);
22277ee25bc5SStella Laurenzo 
22287ee25bc5SStella Laurenzo   //----------------------------------------------------------------------------
22294acd8457SAlex Zinenko   // Mapping of MlirContext.
22305e83a5b4SStella Laurenzo   // Note that this is exported as _BaseContext. The containing, Python level
22315e83a5b4SStella Laurenzo   // __init__.py will subclass it with site-specific functionality and set a
22325e83a5b4SStella Laurenzo   // "Context" attribute on this module.
2233436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
22345e83a5b4SStella Laurenzo   py::class_<PyMlirContext>(m, "_BaseContext", py::module_local())
2235436c6c9cSStella Laurenzo       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2236436c6c9cSStella Laurenzo       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2237436c6c9cSStella Laurenzo       .def("_get_context_again",
2238436c6c9cSStella Laurenzo            [](PyMlirContext &self) {
2239436c6c9cSStella Laurenzo              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2240436c6c9cSStella Laurenzo              return ref.releaseObject();
2241436c6c9cSStella Laurenzo            })
2242436c6c9cSStella Laurenzo       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
22436b0bed7eSJohn Demme       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2244436c6c9cSStella Laurenzo       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2245436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2246436c6c9cSStella Laurenzo                              &PyMlirContext::getCapsule)
2247436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2248436c6c9cSStella Laurenzo       .def("__enter__", &PyMlirContext::contextEnter)
2249436c6c9cSStella Laurenzo       .def("__exit__", &PyMlirContext::contextExit)
2250436c6c9cSStella Laurenzo       .def_property_readonly_static(
2251436c6c9cSStella Laurenzo           "current",
2252436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
2253436c6c9cSStella Laurenzo             auto *context = PyThreadContextEntry::getDefaultContext();
2254436c6c9cSStella Laurenzo             if (!context)
2255436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current Context");
2256436c6c9cSStella Laurenzo             return context;
2257436c6c9cSStella Laurenzo           },
2258436c6c9cSStella Laurenzo           "Gets the Context bound to the current thread or raises ValueError")
2259436c6c9cSStella Laurenzo       .def_property_readonly(
2260436c6c9cSStella Laurenzo           "dialects",
2261436c6c9cSStella Laurenzo           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2262436c6c9cSStella Laurenzo           "Gets a container for accessing dialects by name")
2263436c6c9cSStella Laurenzo       .def_property_readonly(
2264436c6c9cSStella Laurenzo           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2265436c6c9cSStella Laurenzo           "Alias for 'dialect'")
2266436c6c9cSStella Laurenzo       .def(
2267436c6c9cSStella Laurenzo           "get_dialect_descriptor",
2268436c6c9cSStella Laurenzo           [=](PyMlirContext &self, std::string &name) {
2269436c6c9cSStella Laurenzo             MlirDialect dialect = mlirContextGetOrLoadDialect(
2270436c6c9cSStella Laurenzo                 self.get(), {name.data(), name.size()});
2271436c6c9cSStella Laurenzo             if (mlirDialectIsNull(dialect)) {
2272436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
2273436c6c9cSStella Laurenzo                                Twine("Dialect '") + name + "' not found");
2274436c6c9cSStella Laurenzo             }
2275436c6c9cSStella Laurenzo             return PyDialectDescriptor(self.getRef(), dialect);
2276436c6c9cSStella Laurenzo           },
2277a6e7d024SStella Laurenzo           py::arg("dialect_name"),
2278436c6c9cSStella Laurenzo           "Gets or loads a dialect by name, returning its descriptor object")
2279436c6c9cSStella Laurenzo       .def_property(
2280436c6c9cSStella Laurenzo           "allow_unregistered_dialects",
2281436c6c9cSStella Laurenzo           [](PyMlirContext &self) -> bool {
2282436c6c9cSStella Laurenzo             return mlirContextGetAllowUnregisteredDialects(self.get());
2283436c6c9cSStella Laurenzo           },
2284436c6c9cSStella Laurenzo           [](PyMlirContext &self, bool value) {
2285436c6c9cSStella Laurenzo             mlirContextSetAllowUnregisteredDialects(self.get(), value);
22869a9214faSStella Laurenzo           })
22877ee25bc5SStella Laurenzo       .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
22887ee25bc5SStella Laurenzo            py::arg("callback"),
22897ee25bc5SStella Laurenzo            "Attaches a diagnostic handler that will receive callbacks")
2290a6e7d024SStella Laurenzo       .def(
2291a6e7d024SStella Laurenzo           "enable_multithreading",
2292caa159f0SNicolas Vasilache           [](PyMlirContext &self, bool enable) {
2293caa159f0SNicolas Vasilache             mlirContextEnableMultithreading(self.get(), enable);
2294a6e7d024SStella Laurenzo           },
2295a6e7d024SStella Laurenzo           py::arg("enable"))
2296a6e7d024SStella Laurenzo       .def(
2297a6e7d024SStella Laurenzo           "is_registered_operation",
22989a9214faSStella Laurenzo           [](PyMlirContext &self, std::string &name) {
22999a9214faSStella Laurenzo             return mlirContextIsRegisteredOperation(
23009a9214faSStella Laurenzo                 self.get(), MlirStringRef{name.data(), name.size()});
2301a6e7d024SStella Laurenzo           },
23025e83a5b4SStella Laurenzo           py::arg("operation_name"))
23035e83a5b4SStella Laurenzo       .def(
23045e83a5b4SStella Laurenzo           "append_dialect_registry",
23055e83a5b4SStella Laurenzo           [](PyMlirContext &self, PyDialectRegistry &registry) {
23065e83a5b4SStella Laurenzo             mlirContextAppendDialectRegistry(self.get(), registry);
23075e83a5b4SStella Laurenzo           },
23085e83a5b4SStella Laurenzo           py::arg("registry"))
23095e83a5b4SStella Laurenzo       .def("load_all_available_dialects", [](PyMlirContext &self) {
23105e83a5b4SStella Laurenzo         mlirContextLoadAllAvailableDialects(self.get());
23115e83a5b4SStella Laurenzo       });
2312436c6c9cSStella Laurenzo 
2313436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2314436c6c9cSStella Laurenzo   // Mapping of PyDialectDescriptor
2315436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2316f05ff4f7SStella Laurenzo   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2317436c6c9cSStella Laurenzo       .def_property_readonly("namespace",
2318436c6c9cSStella Laurenzo                              [](PyDialectDescriptor &self) {
2319436c6c9cSStella Laurenzo                                MlirStringRef ns =
2320436c6c9cSStella Laurenzo                                    mlirDialectGetNamespace(self.get());
2321436c6c9cSStella Laurenzo                                return py::str(ns.data, ns.length);
2322436c6c9cSStella Laurenzo                              })
2323436c6c9cSStella Laurenzo       .def("__repr__", [](PyDialectDescriptor &self) {
2324436c6c9cSStella Laurenzo         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2325436c6c9cSStella Laurenzo         std::string repr("<DialectDescriptor ");
2326436c6c9cSStella Laurenzo         repr.append(ns.data, ns.length);
2327436c6c9cSStella Laurenzo         repr.append(">");
2328436c6c9cSStella Laurenzo         return repr;
2329436c6c9cSStella Laurenzo       });
2330436c6c9cSStella Laurenzo 
2331436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2332436c6c9cSStella Laurenzo   // Mapping of PyDialects
2333436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2334f05ff4f7SStella Laurenzo   py::class_<PyDialects>(m, "Dialects", py::module_local())
2335436c6c9cSStella Laurenzo       .def("__getitem__",
2336436c6c9cSStella Laurenzo            [=](PyDialects &self, std::string keyName) {
2337436c6c9cSStella Laurenzo              MlirDialect dialect =
2338436c6c9cSStella Laurenzo                  self.getDialectForKey(keyName, /*attrError=*/false);
2339436c6c9cSStella Laurenzo              py::object descriptor =
2340436c6c9cSStella Laurenzo                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2341436c6c9cSStella Laurenzo              return createCustomDialectWrapper(keyName, std::move(descriptor));
2342436c6c9cSStella Laurenzo            })
2343436c6c9cSStella Laurenzo       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2344436c6c9cSStella Laurenzo         MlirDialect dialect =
2345436c6c9cSStella Laurenzo             self.getDialectForKey(attrName, /*attrError=*/true);
2346436c6c9cSStella Laurenzo         py::object descriptor =
2347436c6c9cSStella Laurenzo             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2348436c6c9cSStella Laurenzo         return createCustomDialectWrapper(attrName, std::move(descriptor));
2349436c6c9cSStella Laurenzo       });
2350436c6c9cSStella Laurenzo 
2351436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2352436c6c9cSStella Laurenzo   // Mapping of PyDialect
2353436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2354f05ff4f7SStella Laurenzo   py::class_<PyDialect>(m, "Dialect", py::module_local())
2355a6e7d024SStella Laurenzo       .def(py::init<py::object>(), py::arg("descriptor"))
2356436c6c9cSStella Laurenzo       .def_property_readonly(
2357436c6c9cSStella Laurenzo           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2358436c6c9cSStella Laurenzo       .def("__repr__", [](py::object self) {
2359436c6c9cSStella Laurenzo         auto clazz = self.attr("__class__");
2360436c6c9cSStella Laurenzo         return py::str("<Dialect ") +
2361436c6c9cSStella Laurenzo                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2362436c6c9cSStella Laurenzo                clazz.attr("__module__") + py::str(".") +
2363436c6c9cSStella Laurenzo                clazz.attr("__name__") + py::str(")>");
2364436c6c9cSStella Laurenzo       });
2365436c6c9cSStella Laurenzo 
2366436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
23675e83a5b4SStella Laurenzo   // Mapping of PyDialectRegistry
23685e83a5b4SStella Laurenzo   //----------------------------------------------------------------------------
23695e83a5b4SStella Laurenzo   py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local())
23705e83a5b4SStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
23715e83a5b4SStella Laurenzo                              &PyDialectRegistry::getCapsule)
23725e83a5b4SStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
23735e83a5b4SStella Laurenzo       .def(py::init<>());
23745e83a5b4SStella Laurenzo 
23755e83a5b4SStella Laurenzo   //----------------------------------------------------------------------------
2376436c6c9cSStella Laurenzo   // Mapping of Location
2377436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2378f05ff4f7SStella Laurenzo   py::class_<PyLocation>(m, "Location", py::module_local())
2379436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2380436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2381436c6c9cSStella Laurenzo       .def("__enter__", &PyLocation::contextEnter)
2382436c6c9cSStella Laurenzo       .def("__exit__", &PyLocation::contextExit)
2383436c6c9cSStella Laurenzo       .def("__eq__",
2384436c6c9cSStella Laurenzo            [](PyLocation &self, PyLocation &other) -> bool {
2385436c6c9cSStella Laurenzo              return mlirLocationEqual(self, other);
2386436c6c9cSStella Laurenzo            })
2387436c6c9cSStella Laurenzo       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2388436c6c9cSStella Laurenzo       .def_property_readonly_static(
2389436c6c9cSStella Laurenzo           "current",
2390436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
2391436c6c9cSStella Laurenzo             auto *loc = PyThreadContextEntry::getDefaultLocation();
2392436c6c9cSStella Laurenzo             if (!loc)
2393436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current Location");
2394436c6c9cSStella Laurenzo             return loc;
2395436c6c9cSStella Laurenzo           },
2396436c6c9cSStella Laurenzo           "Gets the Location bound to the current thread or raises ValueError")
2397436c6c9cSStella Laurenzo       .def_static(
2398436c6c9cSStella Laurenzo           "unknown",
2399436c6c9cSStella Laurenzo           [](DefaultingPyMlirContext context) {
2400436c6c9cSStella Laurenzo             return PyLocation(context->getRef(),
2401436c6c9cSStella Laurenzo                               mlirLocationUnknownGet(context->get()));
2402436c6c9cSStella Laurenzo           },
2403436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
2404436c6c9cSStella Laurenzo           "Gets a Location representing an unknown location")
2405436c6c9cSStella Laurenzo       .def_static(
2406e67cbbefSJacques Pienaar           "callsite",
2407e67cbbefSJacques Pienaar           [](PyLocation callee, const std::vector<PyLocation> &frames,
2408e67cbbefSJacques Pienaar              DefaultingPyMlirContext context) {
2409e67cbbefSJacques Pienaar             if (frames.empty())
2410e67cbbefSJacques Pienaar               throw py::value_error("No caller frames provided");
2411e67cbbefSJacques Pienaar             MlirLocation caller = frames.back().get();
2412e2f16be5SMehdi Amini             for (const PyLocation &frame :
2413e67cbbefSJacques Pienaar                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2414e67cbbefSJacques Pienaar               caller = mlirLocationCallSiteGet(frame.get(), caller);
2415e67cbbefSJacques Pienaar             return PyLocation(context->getRef(),
2416e67cbbefSJacques Pienaar                               mlirLocationCallSiteGet(callee.get(), caller));
2417e67cbbefSJacques Pienaar           },
2418e67cbbefSJacques Pienaar           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2419e67cbbefSJacques Pienaar           kContextGetCallSiteLocationDocstring)
2420e67cbbefSJacques Pienaar       .def_static(
2421436c6c9cSStella Laurenzo           "file",
2422436c6c9cSStella Laurenzo           [](std::string filename, int line, int col,
2423436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
2424436c6c9cSStella Laurenzo             return PyLocation(
2425436c6c9cSStella Laurenzo                 context->getRef(),
2426436c6c9cSStella Laurenzo                 mlirLocationFileLineColGet(
2427436c6c9cSStella Laurenzo                     context->get(), toMlirStringRef(filename), line, col));
2428436c6c9cSStella Laurenzo           },
2429436c6c9cSStella Laurenzo           py::arg("filename"), py::arg("line"), py::arg("col"),
2430436c6c9cSStella Laurenzo           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
243104d76d36SJacques Pienaar       .def_static(
24321ab3efacSJacques Pienaar           "fused",
24337ee25bc5SStella Laurenzo           [](const std::vector<PyLocation> &pyLocations,
24347ee25bc5SStella Laurenzo              llvm::Optional<PyAttribute> metadata,
24351ab3efacSJacques Pienaar              DefaultingPyMlirContext context) {
24361ab3efacSJacques Pienaar             llvm::SmallVector<MlirLocation, 4> locations;
24371ab3efacSJacques Pienaar             locations.reserve(pyLocations.size());
24381ab3efacSJacques Pienaar             for (auto &pyLocation : pyLocations)
24391ab3efacSJacques Pienaar               locations.push_back(pyLocation.get());
24401ab3efacSJacques Pienaar             MlirLocation location = mlirLocationFusedGet(
24411ab3efacSJacques Pienaar                 context->get(), locations.size(), locations.data(),
24421ab3efacSJacques Pienaar                 metadata ? metadata->get() : MlirAttribute{0});
24431ab3efacSJacques Pienaar             return PyLocation(context->getRef(), location);
24441ab3efacSJacques Pienaar           },
24451ab3efacSJacques Pienaar           py::arg("locations"), py::arg("metadata") = py::none(),
24461ab3efacSJacques Pienaar           py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
24471ab3efacSJacques Pienaar       .def_static(
244804d76d36SJacques Pienaar           "name",
244904d76d36SJacques Pienaar           [](std::string name, llvm::Optional<PyLocation> childLoc,
245004d76d36SJacques Pienaar              DefaultingPyMlirContext context) {
245104d76d36SJacques Pienaar             return PyLocation(
245204d76d36SJacques Pienaar                 context->getRef(),
245304d76d36SJacques Pienaar                 mlirLocationNameGet(
245404d76d36SJacques Pienaar                     context->get(), toMlirStringRef(name),
245504d76d36SJacques Pienaar                     childLoc ? childLoc->get()
245604d76d36SJacques Pienaar                              : mlirLocationUnknownGet(context->get())));
245704d76d36SJacques Pienaar           },
245804d76d36SJacques Pienaar           py::arg("name"), py::arg("childLoc") = py::none(),
245904d76d36SJacques Pienaar           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2460436c6c9cSStella Laurenzo       .def_property_readonly(
2461436c6c9cSStella Laurenzo           "context",
2462436c6c9cSStella Laurenzo           [](PyLocation &self) { return self.getContext().getObject(); },
2463436c6c9cSStella Laurenzo           "Context that owns the Location")
24647ee25bc5SStella Laurenzo       .def(
24657ee25bc5SStella Laurenzo           "emit_error",
24667ee25bc5SStella Laurenzo           [](PyLocation &self, std::string message) {
24677ee25bc5SStella Laurenzo             mlirEmitError(self, message.c_str());
24687ee25bc5SStella Laurenzo           },
24697ee25bc5SStella Laurenzo           py::arg("message"), "Emits an error at this location")
2470436c6c9cSStella Laurenzo       .def("__repr__", [](PyLocation &self) {
2471436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
2472436c6c9cSStella Laurenzo         mlirLocationPrint(self, printAccum.getCallback(),
2473436c6c9cSStella Laurenzo                           printAccum.getUserData());
2474436c6c9cSStella Laurenzo         return printAccum.join();
2475436c6c9cSStella Laurenzo       });
2476436c6c9cSStella Laurenzo 
2477436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2478436c6c9cSStella Laurenzo   // Mapping of Module
2479436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2480f05ff4f7SStella Laurenzo   py::class_<PyModule>(m, "Module", py::module_local())
2481436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2482436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2483436c6c9cSStella Laurenzo       .def_static(
2484436c6c9cSStella Laurenzo           "parse",
2485436c6c9cSStella Laurenzo           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2486436c6c9cSStella Laurenzo             MlirModule module = mlirModuleCreateParse(
2487436c6c9cSStella Laurenzo                 context->get(), toMlirStringRef(moduleAsm));
2488436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
2489436c6c9cSStella Laurenzo             // in C API.
2490436c6c9cSStella Laurenzo             if (mlirModuleIsNull(module)) {
2491436c6c9cSStella Laurenzo               throw SetPyError(
2492436c6c9cSStella Laurenzo                   PyExc_ValueError,
2493436c6c9cSStella Laurenzo                   "Unable to parse module assembly (see diagnostics)");
2494436c6c9cSStella Laurenzo             }
2495436c6c9cSStella Laurenzo             return PyModule::forModule(module).releaseObject();
2496436c6c9cSStella Laurenzo           },
2497436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
2498436c6c9cSStella Laurenzo           kModuleParseDocstring)
2499436c6c9cSStella Laurenzo       .def_static(
2500436c6c9cSStella Laurenzo           "create",
2501436c6c9cSStella Laurenzo           [](DefaultingPyLocation loc) {
2502436c6c9cSStella Laurenzo             MlirModule module = mlirModuleCreateEmpty(loc);
2503436c6c9cSStella Laurenzo             return PyModule::forModule(module).releaseObject();
2504436c6c9cSStella Laurenzo           },
2505436c6c9cSStella Laurenzo           py::arg("loc") = py::none(), "Creates an empty module")
2506436c6c9cSStella Laurenzo       .def_property_readonly(
2507436c6c9cSStella Laurenzo           "context",
2508436c6c9cSStella Laurenzo           [](PyModule &self) { return self.getContext().getObject(); },
2509436c6c9cSStella Laurenzo           "Context that created the Module")
2510436c6c9cSStella Laurenzo       .def_property_readonly(
2511436c6c9cSStella Laurenzo           "operation",
2512436c6c9cSStella Laurenzo           [](PyModule &self) {
2513436c6c9cSStella Laurenzo             return PyOperation::forOperation(self.getContext(),
2514436c6c9cSStella Laurenzo                                              mlirModuleGetOperation(self.get()),
2515436c6c9cSStella Laurenzo                                              self.getRef().releaseObject())
2516436c6c9cSStella Laurenzo                 .releaseObject();
2517436c6c9cSStella Laurenzo           },
2518436c6c9cSStella Laurenzo           "Accesses the module as an operation")
2519436c6c9cSStella Laurenzo       .def_property_readonly(
2520436c6c9cSStella Laurenzo           "body",
2521436c6c9cSStella Laurenzo           [](PyModule &self) {
252202b6fb21SMehdi Amini             PyOperationRef moduleOp = PyOperation::forOperation(
2523436c6c9cSStella Laurenzo                 self.getContext(), mlirModuleGetOperation(self.get()),
2524436c6c9cSStella Laurenzo                 self.getRef().releaseObject());
252502b6fb21SMehdi Amini             PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2526436c6c9cSStella Laurenzo             return returnBlock;
2527436c6c9cSStella Laurenzo           },
2528436c6c9cSStella Laurenzo           "Return the block for this module")
2529436c6c9cSStella Laurenzo       .def(
2530436c6c9cSStella Laurenzo           "dump",
2531436c6c9cSStella Laurenzo           [](PyModule &self) {
2532436c6c9cSStella Laurenzo             mlirOperationDump(mlirModuleGetOperation(self.get()));
2533436c6c9cSStella Laurenzo           },
2534436c6c9cSStella Laurenzo           kDumpDocstring)
2535436c6c9cSStella Laurenzo       .def(
2536436c6c9cSStella Laurenzo           "__str__",
2537ace1d0adSStella Laurenzo           [](py::object self) {
2538ace1d0adSStella Laurenzo             // Defer to the operation's __str__.
2539ace1d0adSStella Laurenzo             return self.attr("operation").attr("__str__")();
2540436c6c9cSStella Laurenzo           },
2541436c6c9cSStella Laurenzo           kOperationStrDunderDocstring);
2542436c6c9cSStella Laurenzo 
2543436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2544436c6c9cSStella Laurenzo   // Mapping of Operation.
2545436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2546f05ff4f7SStella Laurenzo   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
25471fb2e842SStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
25481fb2e842SStella Laurenzo                              [](PyOperationBase &self) {
25491fb2e842SStella Laurenzo                                return self.getOperation().getCapsule();
25501fb2e842SStella Laurenzo                              })
2551436c6c9cSStella Laurenzo       .def("__eq__",
2552436c6c9cSStella Laurenzo            [](PyOperationBase &self, PyOperationBase &other) {
2553436c6c9cSStella Laurenzo              return &self.getOperation() == &other.getOperation();
2554436c6c9cSStella Laurenzo            })
2555436c6c9cSStella Laurenzo       .def("__eq__",
2556436c6c9cSStella Laurenzo            [](PyOperationBase &self, py::object other) { return false; })
2557f78fe0b7Srkayaith       .def("__hash__",
2558f78fe0b7Srkayaith            [](PyOperationBase &self) {
2559f78fe0b7Srkayaith              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2560f78fe0b7Srkayaith            })
2561436c6c9cSStella Laurenzo       .def_property_readonly("attributes",
2562436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
2563436c6c9cSStella Laurenzo                                return PyOpAttributeMap(
2564436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
2565436c6c9cSStella Laurenzo                              })
2566436c6c9cSStella Laurenzo       .def_property_readonly("operands",
2567436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
2568436c6c9cSStella Laurenzo                                return PyOpOperandList(
2569436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
2570436c6c9cSStella Laurenzo                              })
2571436c6c9cSStella Laurenzo       .def_property_readonly("regions",
2572436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
2573436c6c9cSStella Laurenzo                                return PyRegionList(
2574436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
2575436c6c9cSStella Laurenzo                              })
2576436c6c9cSStella Laurenzo       .def_property_readonly(
2577436c6c9cSStella Laurenzo           "results",
2578436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2579436c6c9cSStella Laurenzo             return PyOpResultList(self.getOperation().getRef());
2580436c6c9cSStella Laurenzo           },
2581436c6c9cSStella Laurenzo           "Returns the list of Operation results.")
2582436c6c9cSStella Laurenzo       .def_property_readonly(
2583436c6c9cSStella Laurenzo           "result",
2584436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2585436c6c9cSStella Laurenzo             auto &operation = self.getOperation();
2586436c6c9cSStella Laurenzo             auto numResults = mlirOperationGetNumResults(operation);
2587436c6c9cSStella Laurenzo             if (numResults != 1) {
2588436c6c9cSStella Laurenzo               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2589436c6c9cSStella Laurenzo               throw SetPyError(
2590436c6c9cSStella Laurenzo                   PyExc_ValueError,
2591436c6c9cSStella Laurenzo                   Twine("Cannot call .result on operation ") +
2592436c6c9cSStella Laurenzo                       StringRef(name.data, name.length) + " which has " +
2593436c6c9cSStella Laurenzo                       Twine(numResults) +
2594436c6c9cSStella Laurenzo                       " results (it is only valid for operations with a "
2595436c6c9cSStella Laurenzo                       "single result)");
2596436c6c9cSStella Laurenzo             }
2597436c6c9cSStella Laurenzo             return PyOpResult(operation.getRef(),
2598436c6c9cSStella Laurenzo                               mlirOperationGetResult(operation, 0));
2599436c6c9cSStella Laurenzo           },
2600436c6c9cSStella Laurenzo           "Shortcut to get an op result if it has only one (throws an error "
2601436c6c9cSStella Laurenzo           "otherwise).")
2602d5429a13Srkayaith       .def_property_readonly(
2603d5429a13Srkayaith           "location",
2604d5429a13Srkayaith           [](PyOperationBase &self) {
2605d5429a13Srkayaith             PyOperation &operation = self.getOperation();
2606d5429a13Srkayaith             return PyLocation(operation.getContext(),
2607d5429a13Srkayaith                               mlirOperationGetLocation(operation.get()));
2608d5429a13Srkayaith           },
2609d5429a13Srkayaith           "Returns the source location the operation was defined or derived "
2610d5429a13Srkayaith           "from.")
2611436c6c9cSStella Laurenzo       .def(
2612436c6c9cSStella Laurenzo           "__str__",
2613436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2614436c6c9cSStella Laurenzo             return self.getAsm(/*binary=*/false,
2615436c6c9cSStella Laurenzo                                /*largeElementsLimit=*/llvm::None,
2616436c6c9cSStella Laurenzo                                /*enableDebugInfo=*/false,
2617436c6c9cSStella Laurenzo                                /*prettyDebugInfo=*/false,
2618436c6c9cSStella Laurenzo                                /*printGenericOpForm=*/false,
2619ace1d0adSStella Laurenzo                                /*useLocalScope=*/false,
2620ace1d0adSStella Laurenzo                                /*assumeVerified=*/false);
2621436c6c9cSStella Laurenzo           },
2622436c6c9cSStella Laurenzo           "Returns the assembly form of the operation.")
2623436c6c9cSStella Laurenzo       .def("print", &PyOperationBase::print,
2624436c6c9cSStella Laurenzo            // Careful: Lots of arguments must match up with print method.
2625436c6c9cSStella Laurenzo            py::arg("file") = py::none(), py::arg("binary") = false,
2626436c6c9cSStella Laurenzo            py::arg("large_elements_limit") = py::none(),
2627436c6c9cSStella Laurenzo            py::arg("enable_debug_info") = false,
2628436c6c9cSStella Laurenzo            py::arg("pretty_debug_info") = false,
2629436c6c9cSStella Laurenzo            py::arg("print_generic_op_form") = false,
2630ace1d0adSStella Laurenzo            py::arg("use_local_scope") = false,
2631ace1d0adSStella Laurenzo            py::arg("assume_verified") = false, kOperationPrintDocstring)
2632436c6c9cSStella Laurenzo       .def("get_asm", &PyOperationBase::getAsm,
2633436c6c9cSStella Laurenzo            // Careful: Lots of arguments must match up with get_asm method.
2634436c6c9cSStella Laurenzo            py::arg("binary") = false,
2635436c6c9cSStella Laurenzo            py::arg("large_elements_limit") = py::none(),
2636436c6c9cSStella Laurenzo            py::arg("enable_debug_info") = false,
2637436c6c9cSStella Laurenzo            py::arg("pretty_debug_info") = false,
2638436c6c9cSStella Laurenzo            py::arg("print_generic_op_form") = false,
2639ace1d0adSStella Laurenzo            py::arg("use_local_scope") = false,
2640ace1d0adSStella Laurenzo            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
2641436c6c9cSStella Laurenzo       .def(
2642436c6c9cSStella Laurenzo           "verify",
2643436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2644436c6c9cSStella Laurenzo             return mlirOperationVerify(self.getOperation());
2645436c6c9cSStella Laurenzo           },
2646436c6c9cSStella Laurenzo           "Verify the operation and return true if it passes, false if it "
264724685aaeSAlex Zinenko           "fails.")
264824685aaeSAlex Zinenko       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
264924685aaeSAlex Zinenko            "Puts self immediately after the other operation in its parent "
265024685aaeSAlex Zinenko            "block.")
265124685aaeSAlex Zinenko       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
265224685aaeSAlex Zinenko            "Puts self immediately before the other operation in its parent "
265324685aaeSAlex Zinenko            "block.")
265424685aaeSAlex Zinenko       .def(
265524685aaeSAlex Zinenko           "detach_from_parent",
265624685aaeSAlex Zinenko           [](PyOperationBase &self) {
265724685aaeSAlex Zinenko             PyOperation &operation = self.getOperation();
265824685aaeSAlex Zinenko             operation.checkValid();
265924685aaeSAlex Zinenko             if (!operation.isAttached())
266024685aaeSAlex Zinenko               throw py::value_error("Detached operation has no parent.");
266124685aaeSAlex Zinenko 
266224685aaeSAlex Zinenko             operation.detachFromParent();
266324685aaeSAlex Zinenko             return operation.createOpView();
266424685aaeSAlex Zinenko           },
266524685aaeSAlex Zinenko           "Detaches the operation from its parent block.");
2666436c6c9cSStella Laurenzo 
2667f05ff4f7SStella Laurenzo   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2668436c6c9cSStella Laurenzo       .def_static("create", &PyOperation::create, py::arg("name"),
2669436c6c9cSStella Laurenzo                   py::arg("results") = py::none(),
2670436c6c9cSStella Laurenzo                   py::arg("operands") = py::none(),
2671436c6c9cSStella Laurenzo                   py::arg("attributes") = py::none(),
2672436c6c9cSStella Laurenzo                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2673436c6c9cSStella Laurenzo                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2674436c6c9cSStella Laurenzo                   kOperationCreateDocstring)
2675c65bb760SJohn Demme       .def_property_readonly("parent",
26761689dadeSJohn Demme                              [](PyOperation &self) -> py::object {
26771689dadeSJohn Demme                                auto parent = self.getParentOperation();
26781689dadeSJohn Demme                                if (parent)
26791689dadeSJohn Demme                                  return parent->getObject();
26801689dadeSJohn Demme                                return py::none();
2681c65bb760SJohn Demme                              })
268249745f87SMike Urbach       .def("erase", &PyOperation::erase)
2683774818c0SDominik Grewe       .def("clone", &PyOperation::clone, py::arg("ip") = py::none())
26840126e906SJohn Demme       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
26850126e906SJohn Demme                              &PyOperation::getCapsule)
26860126e906SJohn Demme       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2687436c6c9cSStella Laurenzo       .def_property_readonly("name",
2688436c6c9cSStella Laurenzo                              [](PyOperation &self) {
268949745f87SMike Urbach                                self.checkValid();
2690436c6c9cSStella Laurenzo                                MlirOperation operation = self.get();
2691436c6c9cSStella Laurenzo                                MlirStringRef name = mlirIdentifierStr(
2692436c6c9cSStella Laurenzo                                    mlirOperationGetName(operation));
2693436c6c9cSStella Laurenzo                                return py::str(name.data, name.length);
2694436c6c9cSStella Laurenzo                              })
2695436c6c9cSStella Laurenzo       .def_property_readonly(
2696436c6c9cSStella Laurenzo           "context",
269749745f87SMike Urbach           [](PyOperation &self) {
269849745f87SMike Urbach             self.checkValid();
269949745f87SMike Urbach             return self.getContext().getObject();
270049745f87SMike Urbach           },
2701436c6c9cSStella Laurenzo           "Context that owns the Operation")
2702436c6c9cSStella Laurenzo       .def_property_readonly("opview", &PyOperation::createOpView);
2703436c6c9cSStella Laurenzo 
2704436c6c9cSStella Laurenzo   auto opViewClass =
2705f05ff4f7SStella Laurenzo       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2706a6e7d024SStella Laurenzo           .def(py::init<py::object>(), py::arg("operation"))
2707436c6c9cSStella Laurenzo           .def_property_readonly("operation", &PyOpView::getOperationObject)
2708436c6c9cSStella Laurenzo           .def_property_readonly(
2709436c6c9cSStella Laurenzo               "context",
2710436c6c9cSStella Laurenzo               [](PyOpView &self) {
2711436c6c9cSStella Laurenzo                 return self.getOperation().getContext().getObject();
2712436c6c9cSStella Laurenzo               },
2713436c6c9cSStella Laurenzo               "Context that owns the Operation")
2714436c6c9cSStella Laurenzo           .def("__str__", [](PyOpView &self) {
2715436c6c9cSStella Laurenzo             return py::str(self.getOperationObject());
2716436c6c9cSStella Laurenzo           });
2717436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2718436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2719436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2720436c6c9cSStella Laurenzo   opViewClass.attr("build_generic") = classmethod(
2721436c6c9cSStella Laurenzo       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2722436c6c9cSStella Laurenzo       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2723436c6c9cSStella Laurenzo       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2724436c6c9cSStella Laurenzo       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2725436c6c9cSStella Laurenzo       "Builds a specific, generated OpView based on class level attributes.");
2726436c6c9cSStella Laurenzo 
2727436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2728436c6c9cSStella Laurenzo   // Mapping of PyRegion.
2729436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2730f05ff4f7SStella Laurenzo   py::class_<PyRegion>(m, "Region", py::module_local())
2731436c6c9cSStella Laurenzo       .def_property_readonly(
2732436c6c9cSStella Laurenzo           "blocks",
2733436c6c9cSStella Laurenzo           [](PyRegion &self) {
2734436c6c9cSStella Laurenzo             return PyBlockList(self.getParentOperation(), self.get());
2735436c6c9cSStella Laurenzo           },
2736436c6c9cSStella Laurenzo           "Returns a forward-optimized sequence of blocks.")
273778f2dae0SAlex Zinenko       .def_property_readonly(
273878f2dae0SAlex Zinenko           "owner",
273978f2dae0SAlex Zinenko           [](PyRegion &self) {
274078f2dae0SAlex Zinenko             return self.getParentOperation()->createOpView();
274178f2dae0SAlex Zinenko           },
274278f2dae0SAlex Zinenko           "Returns the operation owning this region.")
2743436c6c9cSStella Laurenzo       .def(
2744436c6c9cSStella Laurenzo           "__iter__",
2745436c6c9cSStella Laurenzo           [](PyRegion &self) {
2746436c6c9cSStella Laurenzo             self.checkValid();
2747436c6c9cSStella Laurenzo             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2748436c6c9cSStella Laurenzo             return PyBlockIterator(self.getParentOperation(), firstBlock);
2749436c6c9cSStella Laurenzo           },
2750436c6c9cSStella Laurenzo           "Iterates over blocks in the region.")
2751436c6c9cSStella Laurenzo       .def("__eq__",
2752436c6c9cSStella Laurenzo            [](PyRegion &self, PyRegion &other) {
2753436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2754436c6c9cSStella Laurenzo            })
2755436c6c9cSStella Laurenzo       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2756436c6c9cSStella Laurenzo 
2757436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2758436c6c9cSStella Laurenzo   // Mapping of PyBlock.
2759436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2760f05ff4f7SStella Laurenzo   py::class_<PyBlock>(m, "Block", py::module_local())
2761436c6c9cSStella Laurenzo       .def_property_readonly(
276296fbd5cdSJohn Demme           "owner",
276396fbd5cdSJohn Demme           [](PyBlock &self) {
276496fbd5cdSJohn Demme             return self.getParentOperation()->createOpView();
276596fbd5cdSJohn Demme           },
276696fbd5cdSJohn Demme           "Returns the owning operation of this block.")
276796fbd5cdSJohn Demme       .def_property_readonly(
27688e6c55c9SStella Laurenzo           "region",
27698e6c55c9SStella Laurenzo           [](PyBlock &self) {
27708e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
27718e6c55c9SStella Laurenzo             return PyRegion(self.getParentOperation(), region);
27728e6c55c9SStella Laurenzo           },
27738e6c55c9SStella Laurenzo           "Returns the owning region of this block.")
27748e6c55c9SStella Laurenzo       .def_property_readonly(
2775436c6c9cSStella Laurenzo           "arguments",
2776436c6c9cSStella Laurenzo           [](PyBlock &self) {
2777436c6c9cSStella Laurenzo             return PyBlockArgumentList(self.getParentOperation(), self.get());
2778436c6c9cSStella Laurenzo           },
2779436c6c9cSStella Laurenzo           "Returns a list of block arguments.")
2780436c6c9cSStella Laurenzo       .def_property_readonly(
2781436c6c9cSStella Laurenzo           "operations",
2782436c6c9cSStella Laurenzo           [](PyBlock &self) {
2783436c6c9cSStella Laurenzo             return PyOperationList(self.getParentOperation(), self.get());
2784436c6c9cSStella Laurenzo           },
2785436c6c9cSStella Laurenzo           "Returns a forward-optimized sequence of operations.")
278678f2dae0SAlex Zinenko       .def_static(
278778f2dae0SAlex Zinenko           "create_at_start",
278878f2dae0SAlex Zinenko           [](PyRegion &parent, py::list pyArgTypes) {
278978f2dae0SAlex Zinenko             parent.checkValid();
279078f2dae0SAlex Zinenko             llvm::SmallVector<MlirType, 4> argTypes;
2791e084679fSRiver Riddle             llvm::SmallVector<MlirLocation, 4> argLocs;
279278f2dae0SAlex Zinenko             argTypes.reserve(pyArgTypes.size());
2793e084679fSRiver Riddle             argLocs.reserve(pyArgTypes.size());
279478f2dae0SAlex Zinenko             for (auto &pyArg : pyArgTypes) {
279578f2dae0SAlex Zinenko               argTypes.push_back(pyArg.cast<PyType &>());
2796e084679fSRiver Riddle               // TODO: Pass in a proper location here.
2797e084679fSRiver Riddle               argLocs.push_back(
2798e084679fSRiver Riddle                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
279978f2dae0SAlex Zinenko             }
280078f2dae0SAlex Zinenko 
2801e084679fSRiver Riddle             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2802e084679fSRiver Riddle                                               argLocs.data());
280378f2dae0SAlex Zinenko             mlirRegionInsertOwnedBlock(parent, 0, block);
280478f2dae0SAlex Zinenko             return PyBlock(parent.getParentOperation(), block);
280578f2dae0SAlex Zinenko           },
2806a6e7d024SStella Laurenzo           py::arg("parent"), py::arg("arg_types") = py::list(),
280778f2dae0SAlex Zinenko           "Creates and returns a new Block at the beginning of the given "
280878f2dae0SAlex Zinenko           "region (with given argument types).")
2809436c6c9cSStella Laurenzo       .def(
28108d8738f6SJohn Demme           "append_to",
28118d8738f6SJohn Demme           [](PyBlock &self, PyRegion &region) {
28128d8738f6SJohn Demme             MlirBlock b = self.get();
28138d8738f6SJohn Demme             if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
28148d8738f6SJohn Demme               mlirBlockDetach(b);
28158d8738f6SJohn Demme             mlirRegionAppendOwnedBlock(region.get(), b);
28168d8738f6SJohn Demme           },
28178d8738f6SJohn Demme           "Append this block to a region, transferring ownership if necessary")
28188d8738f6SJohn Demme       .def(
28198e6c55c9SStella Laurenzo           "create_before",
28208e6c55c9SStella Laurenzo           [](PyBlock &self, py::args pyArgTypes) {
28218e6c55c9SStella Laurenzo             self.checkValid();
28228e6c55c9SStella Laurenzo             llvm::SmallVector<MlirType, 4> argTypes;
2823e084679fSRiver Riddle             llvm::SmallVector<MlirLocation, 4> argLocs;
28248e6c55c9SStella Laurenzo             argTypes.reserve(pyArgTypes.size());
2825e084679fSRiver Riddle             argLocs.reserve(pyArgTypes.size());
28268e6c55c9SStella Laurenzo             for (auto &pyArg : pyArgTypes) {
28278e6c55c9SStella Laurenzo               argTypes.push_back(pyArg.cast<PyType &>());
2828e084679fSRiver Riddle               // TODO: Pass in a proper location here.
2829e084679fSRiver Riddle               argLocs.push_back(
2830e084679fSRiver Riddle                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
28318e6c55c9SStella Laurenzo             }
28328e6c55c9SStella Laurenzo 
2833e084679fSRiver Riddle             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2834e084679fSRiver Riddle                                               argLocs.data());
28358e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
28368e6c55c9SStella Laurenzo             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
28378e6c55c9SStella Laurenzo             return PyBlock(self.getParentOperation(), block);
28388e6c55c9SStella Laurenzo           },
28398e6c55c9SStella Laurenzo           "Creates and returns a new Block before this block "
28408e6c55c9SStella Laurenzo           "(with given argument types).")
28418e6c55c9SStella Laurenzo       .def(
28428e6c55c9SStella Laurenzo           "create_after",
28438e6c55c9SStella Laurenzo           [](PyBlock &self, py::args pyArgTypes) {
28448e6c55c9SStella Laurenzo             self.checkValid();
28458e6c55c9SStella Laurenzo             llvm::SmallVector<MlirType, 4> argTypes;
2846e084679fSRiver Riddle             llvm::SmallVector<MlirLocation, 4> argLocs;
28478e6c55c9SStella Laurenzo             argTypes.reserve(pyArgTypes.size());
2848e084679fSRiver Riddle             argLocs.reserve(pyArgTypes.size());
28498e6c55c9SStella Laurenzo             for (auto &pyArg : pyArgTypes) {
28508e6c55c9SStella Laurenzo               argTypes.push_back(pyArg.cast<PyType &>());
28518e6c55c9SStella Laurenzo 
2852e084679fSRiver Riddle               // TODO: Pass in a proper location here.
2853e084679fSRiver Riddle               argLocs.push_back(
2854e084679fSRiver Riddle                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2855e084679fSRiver Riddle             }
2856e084679fSRiver Riddle             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2857e084679fSRiver Riddle                                               argLocs.data());
28588e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
28598e6c55c9SStella Laurenzo             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
28608e6c55c9SStella Laurenzo             return PyBlock(self.getParentOperation(), block);
28618e6c55c9SStella Laurenzo           },
28628e6c55c9SStella Laurenzo           "Creates and returns a new Block after this block "
28638e6c55c9SStella Laurenzo           "(with given argument types).")
28648e6c55c9SStella Laurenzo       .def(
2865436c6c9cSStella Laurenzo           "__iter__",
2866436c6c9cSStella Laurenzo           [](PyBlock &self) {
2867436c6c9cSStella Laurenzo             self.checkValid();
2868436c6c9cSStella Laurenzo             MlirOperation firstOperation =
2869436c6c9cSStella Laurenzo                 mlirBlockGetFirstOperation(self.get());
2870436c6c9cSStella Laurenzo             return PyOperationIterator(self.getParentOperation(),
2871436c6c9cSStella Laurenzo                                        firstOperation);
2872436c6c9cSStella Laurenzo           },
2873436c6c9cSStella Laurenzo           "Iterates over operations in the block.")
2874436c6c9cSStella Laurenzo       .def("__eq__",
2875436c6c9cSStella Laurenzo            [](PyBlock &self, PyBlock &other) {
2876436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2877436c6c9cSStella Laurenzo            })
2878436c6c9cSStella Laurenzo       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2879436c6c9cSStella Laurenzo       .def(
2880436c6c9cSStella Laurenzo           "__str__",
2881436c6c9cSStella Laurenzo           [](PyBlock &self) {
2882436c6c9cSStella Laurenzo             self.checkValid();
2883436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2884436c6c9cSStella Laurenzo             mlirBlockPrint(self.get(), printAccum.getCallback(),
2885436c6c9cSStella Laurenzo                            printAccum.getUserData());
2886436c6c9cSStella Laurenzo             return printAccum.join();
2887436c6c9cSStella Laurenzo           },
288824685aaeSAlex Zinenko           "Returns the assembly form of the block.")
288924685aaeSAlex Zinenko       .def(
289024685aaeSAlex Zinenko           "append",
289124685aaeSAlex Zinenko           [](PyBlock &self, PyOperationBase &operation) {
289224685aaeSAlex Zinenko             if (operation.getOperation().isAttached())
289324685aaeSAlex Zinenko               operation.getOperation().detachFromParent();
289424685aaeSAlex Zinenko 
289524685aaeSAlex Zinenko             MlirOperation mlirOperation = operation.getOperation().get();
289624685aaeSAlex Zinenko             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
289724685aaeSAlex Zinenko             operation.getOperation().setAttached(
289824685aaeSAlex Zinenko                 self.getParentOperation().getObject());
289924685aaeSAlex Zinenko           },
2900a6e7d024SStella Laurenzo           py::arg("operation"),
290124685aaeSAlex Zinenko           "Appends an operation to this block. If the operation is currently "
290224685aaeSAlex Zinenko           "in another block, it will be moved.");
2903436c6c9cSStella Laurenzo 
2904436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2905436c6c9cSStella Laurenzo   // Mapping of PyInsertionPoint.
2906436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2907436c6c9cSStella Laurenzo 
2908f05ff4f7SStella Laurenzo   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2909436c6c9cSStella Laurenzo       .def(py::init<PyBlock &>(), py::arg("block"),
2910436c6c9cSStella Laurenzo            "Inserts after the last operation but still inside the block.")
2911436c6c9cSStella Laurenzo       .def("__enter__", &PyInsertionPoint::contextEnter)
2912436c6c9cSStella Laurenzo       .def("__exit__", &PyInsertionPoint::contextExit)
2913436c6c9cSStella Laurenzo       .def_property_readonly_static(
2914436c6c9cSStella Laurenzo           "current",
2915436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
2916436c6c9cSStella Laurenzo             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2917436c6c9cSStella Laurenzo             if (!ip)
2918436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2919436c6c9cSStella Laurenzo             return ip;
2920436c6c9cSStella Laurenzo           },
2921436c6c9cSStella Laurenzo           "Gets the InsertionPoint bound to the current thread or raises "
2922436c6c9cSStella Laurenzo           "ValueError if none has been set")
2923436c6c9cSStella Laurenzo       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2924436c6c9cSStella Laurenzo            "Inserts before a referenced operation.")
2925436c6c9cSStella Laurenzo       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2926436c6c9cSStella Laurenzo                   py::arg("block"), "Inserts at the beginning of the block.")
2927436c6c9cSStella Laurenzo       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2928436c6c9cSStella Laurenzo                   py::arg("block"), "Inserts before the block terminator.")
2929436c6c9cSStella Laurenzo       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
29308e6c55c9SStella Laurenzo            "Inserts an operation.")
29318e6c55c9SStella Laurenzo       .def_property_readonly(
29328e6c55c9SStella Laurenzo           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
29338e6c55c9SStella Laurenzo           "Returns the block that this InsertionPoint points to.");
2934436c6c9cSStella Laurenzo 
2935436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2936436c6c9cSStella Laurenzo   // Mapping of PyAttribute.
2937436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2938f05ff4f7SStella Laurenzo   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2939b57d6fe4SStella Laurenzo       // Delegate to the PyAttribute copy constructor, which will also lifetime
2940b57d6fe4SStella Laurenzo       // extend the backing context which owns the MlirAttribute.
2941b57d6fe4SStella Laurenzo       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2942b57d6fe4SStella Laurenzo            "Casts the passed attribute to the generic Attribute")
2943436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2944436c6c9cSStella Laurenzo                              &PyAttribute::getCapsule)
2945436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2946436c6c9cSStella Laurenzo       .def_static(
2947436c6c9cSStella Laurenzo           "parse",
2948436c6c9cSStella Laurenzo           [](std::string attrSpec, DefaultingPyMlirContext context) {
2949436c6c9cSStella Laurenzo             MlirAttribute type = mlirAttributeParseGet(
2950436c6c9cSStella Laurenzo                 context->get(), toMlirStringRef(attrSpec));
2951436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
2952436c6c9cSStella Laurenzo             // in C API.
2953436c6c9cSStella Laurenzo             if (mlirAttributeIsNull(type)) {
2954436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
2955436c6c9cSStella Laurenzo                                Twine("Unable to parse attribute: '") +
2956436c6c9cSStella Laurenzo                                    attrSpec + "'");
2957436c6c9cSStella Laurenzo             }
2958436c6c9cSStella Laurenzo             return PyAttribute(context->getRef(), type);
2959436c6c9cSStella Laurenzo           },
2960436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
2961436c6c9cSStella Laurenzo           "Parses an attribute from an assembly form")
2962436c6c9cSStella Laurenzo       .def_property_readonly(
2963436c6c9cSStella Laurenzo           "context",
2964436c6c9cSStella Laurenzo           [](PyAttribute &self) { return self.getContext().getObject(); },
2965436c6c9cSStella Laurenzo           "Context that owns the Attribute")
2966436c6c9cSStella Laurenzo       .def_property_readonly("type",
2967436c6c9cSStella Laurenzo                              [](PyAttribute &self) {
2968436c6c9cSStella Laurenzo                                return PyType(self.getContext()->getRef(),
2969436c6c9cSStella Laurenzo                                              mlirAttributeGetType(self));
2970436c6c9cSStella Laurenzo                              })
2971436c6c9cSStella Laurenzo       .def(
2972436c6c9cSStella Laurenzo           "get_named",
2973436c6c9cSStella Laurenzo           [](PyAttribute &self, std::string name) {
2974436c6c9cSStella Laurenzo             return PyNamedAttribute(self, std::move(name));
2975436c6c9cSStella Laurenzo           },
2976436c6c9cSStella Laurenzo           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2977436c6c9cSStella Laurenzo       .def("__eq__",
2978436c6c9cSStella Laurenzo            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2979436c6c9cSStella Laurenzo       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2980f78fe0b7Srkayaith       .def("__hash__",
2981f78fe0b7Srkayaith            [](PyAttribute &self) {
2982f78fe0b7Srkayaith              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2983f78fe0b7Srkayaith            })
2984436c6c9cSStella Laurenzo       .def(
2985436c6c9cSStella Laurenzo           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2986436c6c9cSStella Laurenzo           kDumpDocstring)
2987436c6c9cSStella Laurenzo       .def(
2988436c6c9cSStella Laurenzo           "__str__",
2989436c6c9cSStella Laurenzo           [](PyAttribute &self) {
2990436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2991436c6c9cSStella Laurenzo             mlirAttributePrint(self, printAccum.getCallback(),
2992436c6c9cSStella Laurenzo                                printAccum.getUserData());
2993436c6c9cSStella Laurenzo             return printAccum.join();
2994436c6c9cSStella Laurenzo           },
2995436c6c9cSStella Laurenzo           "Returns the assembly form of the Attribute.")
2996436c6c9cSStella Laurenzo       .def("__repr__", [](PyAttribute &self) {
2997436c6c9cSStella Laurenzo         // Generally, assembly formats are not printed for __repr__ because
2998436c6c9cSStella Laurenzo         // this can cause exceptionally long debug output and exceptions.
2999436c6c9cSStella Laurenzo         // However, attribute values are generally considered useful and are
3000436c6c9cSStella Laurenzo         // printed. This may need to be re-evaluated if debug dumps end up
3001436c6c9cSStella Laurenzo         // being excessive.
3002436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
3003436c6c9cSStella Laurenzo         printAccum.parts.append("Attribute(");
3004436c6c9cSStella Laurenzo         mlirAttributePrint(self, printAccum.getCallback(),
3005436c6c9cSStella Laurenzo                            printAccum.getUserData());
3006436c6c9cSStella Laurenzo         printAccum.parts.append(")");
3007436c6c9cSStella Laurenzo         return printAccum.join();
3008436c6c9cSStella Laurenzo       });
3009436c6c9cSStella Laurenzo 
3010436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3011436c6c9cSStella Laurenzo   // Mapping of PyNamedAttribute
3012436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3013f05ff4f7SStella Laurenzo   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
3014436c6c9cSStella Laurenzo       .def("__repr__",
3015436c6c9cSStella Laurenzo            [](PyNamedAttribute &self) {
3016436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
3017436c6c9cSStella Laurenzo              printAccum.parts.append("NamedAttribute(");
3018436c6c9cSStella Laurenzo              printAccum.parts.append(
3019120591e1SRiver Riddle                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
3020120591e1SRiver Riddle                          mlirIdentifierStr(self.namedAttr.name).length));
3021436c6c9cSStella Laurenzo              printAccum.parts.append("=");
3022436c6c9cSStella Laurenzo              mlirAttributePrint(self.namedAttr.attribute,
3023436c6c9cSStella Laurenzo                                 printAccum.getCallback(),
3024436c6c9cSStella Laurenzo                                 printAccum.getUserData());
3025436c6c9cSStella Laurenzo              printAccum.parts.append(")");
3026436c6c9cSStella Laurenzo              return printAccum.join();
3027436c6c9cSStella Laurenzo            })
3028436c6c9cSStella Laurenzo       .def_property_readonly(
3029436c6c9cSStella Laurenzo           "name",
3030436c6c9cSStella Laurenzo           [](PyNamedAttribute &self) {
3031436c6c9cSStella Laurenzo             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
3032436c6c9cSStella Laurenzo                            mlirIdentifierStr(self.namedAttr.name).length);
3033436c6c9cSStella Laurenzo           },
3034436c6c9cSStella Laurenzo           "The name of the NamedAttribute binding")
3035436c6c9cSStella Laurenzo       .def_property_readonly(
3036436c6c9cSStella Laurenzo           "attr",
3037436c6c9cSStella Laurenzo           [](PyNamedAttribute &self) {
3038436c6c9cSStella Laurenzo             // TODO: When named attribute is removed/refactored, also remove
3039436c6c9cSStella Laurenzo             // this constructor (it does an inefficient table lookup).
3040436c6c9cSStella Laurenzo             auto contextRef = PyMlirContext::forContext(
3041436c6c9cSStella Laurenzo                 mlirAttributeGetContext(self.namedAttr.attribute));
3042436c6c9cSStella Laurenzo             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
3043436c6c9cSStella Laurenzo           },
3044436c6c9cSStella Laurenzo           py::keep_alive<0, 1>(),
3045436c6c9cSStella Laurenzo           "The underlying generic attribute of the NamedAttribute binding");
3046436c6c9cSStella Laurenzo 
3047436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3048436c6c9cSStella Laurenzo   // Mapping of PyType.
3049436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3050f05ff4f7SStella Laurenzo   py::class_<PyType>(m, "Type", py::module_local())
3051b57d6fe4SStella Laurenzo       // Delegate to the PyType copy constructor, which will also lifetime
3052b57d6fe4SStella Laurenzo       // extend the backing context which owns the MlirType.
3053b57d6fe4SStella Laurenzo       .def(py::init<PyType &>(), py::arg("cast_from_type"),
3054b57d6fe4SStella Laurenzo            "Casts the passed type to the generic Type")
3055436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3056436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3057436c6c9cSStella Laurenzo       .def_static(
3058436c6c9cSStella Laurenzo           "parse",
3059436c6c9cSStella Laurenzo           [](std::string typeSpec, DefaultingPyMlirContext context) {
3060436c6c9cSStella Laurenzo             MlirType type =
3061436c6c9cSStella Laurenzo                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3062436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
3063436c6c9cSStella Laurenzo             // in C API.
3064436c6c9cSStella Laurenzo             if (mlirTypeIsNull(type)) {
3065436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
3066436c6c9cSStella Laurenzo                                Twine("Unable to parse type: '") + typeSpec +
3067436c6c9cSStella Laurenzo                                    "'");
3068436c6c9cSStella Laurenzo             }
3069436c6c9cSStella Laurenzo             return PyType(context->getRef(), type);
3070436c6c9cSStella Laurenzo           },
3071436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
3072436c6c9cSStella Laurenzo           kContextParseTypeDocstring)
3073436c6c9cSStella Laurenzo       .def_property_readonly(
3074436c6c9cSStella Laurenzo           "context", [](PyType &self) { return self.getContext().getObject(); },
3075436c6c9cSStella Laurenzo           "Context that owns the Type")
3076436c6c9cSStella Laurenzo       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3077436c6c9cSStella Laurenzo       .def("__eq__", [](PyType &self, py::object &other) { return false; })
3078f78fe0b7Srkayaith       .def("__hash__",
3079f78fe0b7Srkayaith            [](PyType &self) {
3080f78fe0b7Srkayaith              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3081f78fe0b7Srkayaith            })
3082436c6c9cSStella Laurenzo       .def(
3083436c6c9cSStella Laurenzo           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3084436c6c9cSStella Laurenzo       .def(
3085436c6c9cSStella Laurenzo           "__str__",
3086436c6c9cSStella Laurenzo           [](PyType &self) {
3087436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
3088436c6c9cSStella Laurenzo             mlirTypePrint(self, printAccum.getCallback(),
3089436c6c9cSStella Laurenzo                           printAccum.getUserData());
3090436c6c9cSStella Laurenzo             return printAccum.join();
3091436c6c9cSStella Laurenzo           },
3092436c6c9cSStella Laurenzo           "Returns the assembly form of the type.")
3093436c6c9cSStella Laurenzo       .def("__repr__", [](PyType &self) {
3094436c6c9cSStella Laurenzo         // Generally, assembly formats are not printed for __repr__ because
3095436c6c9cSStella Laurenzo         // this can cause exceptionally long debug output and exceptions.
3096436c6c9cSStella Laurenzo         // However, types are an exception as they typically have compact
3097436c6c9cSStella Laurenzo         // assembly forms and printing them is useful.
3098436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
3099436c6c9cSStella Laurenzo         printAccum.parts.append("Type(");
3100436c6c9cSStella Laurenzo         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
3101436c6c9cSStella Laurenzo         printAccum.parts.append(")");
3102436c6c9cSStella Laurenzo         return printAccum.join();
3103436c6c9cSStella Laurenzo       });
3104436c6c9cSStella Laurenzo 
3105436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3106436c6c9cSStella Laurenzo   // Mapping of Value.
3107436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
3108f05ff4f7SStella Laurenzo   py::class_<PyValue>(m, "Value", py::module_local())
31093f3d1c90SMike Urbach       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
31103f3d1c90SMike Urbach       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3111436c6c9cSStella Laurenzo       .def_property_readonly(
3112436c6c9cSStella Laurenzo           "context",
3113436c6c9cSStella Laurenzo           [](PyValue &self) { return self.getParentOperation()->getContext(); },
3114436c6c9cSStella Laurenzo           "Context in which the value lives.")
3115436c6c9cSStella Laurenzo       .def(
3116436c6c9cSStella Laurenzo           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3117436c6c9cSStella Laurenzo           kDumpDocstring)
31185664c5e2SJohn Demme       .def_property_readonly(
31195664c5e2SJohn Demme           "owner",
31205664c5e2SJohn Demme           [](PyValue &self) {
31215664c5e2SJohn Demme             assert(mlirOperationEqual(self.getParentOperation()->get(),
31225664c5e2SJohn Demme                                       mlirOpResultGetOwner(self.get())) &&
31235664c5e2SJohn Demme                    "expected the owner of the value in Python to match that in "
31245664c5e2SJohn Demme                    "the IR");
31255664c5e2SJohn Demme             return self.getParentOperation().getObject();
31265664c5e2SJohn Demme           })
3127436c6c9cSStella Laurenzo       .def("__eq__",
3128436c6c9cSStella Laurenzo            [](PyValue &self, PyValue &other) {
3129436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
3130436c6c9cSStella Laurenzo            })
3131436c6c9cSStella Laurenzo       .def("__eq__", [](PyValue &self, py::object other) { return false; })
3132f78fe0b7Srkayaith       .def("__hash__",
3133f78fe0b7Srkayaith            [](PyValue &self) {
3134f78fe0b7Srkayaith              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3135f78fe0b7Srkayaith            })
3136436c6c9cSStella Laurenzo       .def(
3137436c6c9cSStella Laurenzo           "__str__",
3138436c6c9cSStella Laurenzo           [](PyValue &self) {
3139436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
3140436c6c9cSStella Laurenzo             printAccum.parts.append("Value(");
3141436c6c9cSStella Laurenzo             mlirValuePrint(self.get(), printAccum.getCallback(),
3142436c6c9cSStella Laurenzo                            printAccum.getUserData());
3143436c6c9cSStella Laurenzo             printAccum.parts.append(")");
3144436c6c9cSStella Laurenzo             return printAccum.join();
3145436c6c9cSStella Laurenzo           },
3146436c6c9cSStella Laurenzo           kValueDunderStrDocstring)
3147436c6c9cSStella Laurenzo       .def_property_readonly("type", [](PyValue &self) {
3148436c6c9cSStella Laurenzo         return PyType(self.getParentOperation()->getContext(),
3149436c6c9cSStella Laurenzo                       mlirValueGetType(self.get()));
3150436c6c9cSStella Laurenzo       });
3151436c6c9cSStella Laurenzo   PyBlockArgument::bind(m);
3152436c6c9cSStella Laurenzo   PyOpResult::bind(m);
3153436c6c9cSStella Laurenzo 
315430d61893SAlex Zinenko   //----------------------------------------------------------------------------
315530d61893SAlex Zinenko   // Mapping of SymbolTable.
315630d61893SAlex Zinenko   //----------------------------------------------------------------------------
315730d61893SAlex Zinenko   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
315830d61893SAlex Zinenko       .def(py::init<PyOperationBase &>())
315930d61893SAlex Zinenko       .def("__getitem__", &PySymbolTable::dunderGetItem)
3160a6e7d024SStella Laurenzo       .def("insert", &PySymbolTable::insert, py::arg("operation"))
3161a6e7d024SStella Laurenzo       .def("erase", &PySymbolTable::erase, py::arg("operation"))
316230d61893SAlex Zinenko       .def("__delitem__", &PySymbolTable::dunderDel)
3163bdc31837SStella Laurenzo       .def("__contains__",
3164bdc31837SStella Laurenzo            [](PySymbolTable &table, const std::string &name) {
316530d61893SAlex Zinenko              return !mlirOperationIsNull(mlirSymbolTableLookup(
316630d61893SAlex Zinenko                  table, mlirStringRefCreate(name.data(), name.length())));
3167bdc31837SStella Laurenzo            })
3168bdc31837SStella Laurenzo       // Static helpers.
3169bdc31837SStella Laurenzo       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3170bdc31837SStella Laurenzo                   py::arg("symbol"), py::arg("name"))
3171bdc31837SStella Laurenzo       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3172bdc31837SStella Laurenzo                   py::arg("symbol"))
3173bdc31837SStella Laurenzo       .def_static("get_visibility", &PySymbolTable::getVisibility,
3174bdc31837SStella Laurenzo                   py::arg("symbol"))
3175bdc31837SStella Laurenzo       .def_static("set_visibility", &PySymbolTable::setVisibility,
3176bdc31837SStella Laurenzo                   py::arg("symbol"), py::arg("visibility"))
3177bdc31837SStella Laurenzo       .def_static("replace_all_symbol_uses",
3178bdc31837SStella Laurenzo                   &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3179bdc31837SStella Laurenzo                   py::arg("new_symbol"), py::arg("from_op"))
3180bdc31837SStella Laurenzo       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3181bdc31837SStella Laurenzo                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
3182bdc31837SStella Laurenzo                   py::arg("callback"));
318330d61893SAlex Zinenko 
3184436c6c9cSStella Laurenzo   // Container bindings.
3185436c6c9cSStella Laurenzo   PyBlockArgumentList::bind(m);
3186436c6c9cSStella Laurenzo   PyBlockIterator::bind(m);
3187436c6c9cSStella Laurenzo   PyBlockList::bind(m);
3188436c6c9cSStella Laurenzo   PyOperationIterator::bind(m);
3189436c6c9cSStella Laurenzo   PyOperationList::bind(m);
3190436c6c9cSStella Laurenzo   PyOpAttributeMap::bind(m);
3191436c6c9cSStella Laurenzo   PyOpOperandList::bind(m);
3192436c6c9cSStella Laurenzo   PyOpResultList::bind(m);
3193436c6c9cSStella Laurenzo   PyRegionIterator::bind(m);
3194436c6c9cSStella Laurenzo   PyRegionList::bind(m);
31954acd8457SAlex Zinenko 
31964acd8457SAlex Zinenko   // Debug bindings.
31974acd8457SAlex Zinenko   PyGlobalDebugFlag::bind(m);
3198436c6c9cSStella Laurenzo }
3199