1436c6c9cSStella Laurenzo //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9436c6c9cSStella Laurenzo #include "IRModule.h"
10436c6c9cSStella Laurenzo 
11436c6c9cSStella Laurenzo #include "Globals.h"
12436c6c9cSStella Laurenzo #include "PybindUtils.h"
13436c6c9cSStella Laurenzo 
14436c6c9cSStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h"
15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
174acd8457SAlex Zinenko #include "mlir-c/Debug.h"
183f3d1c90SMike Urbach #include "mlir-c/IR.h"
19436c6c9cSStella Laurenzo #include "mlir-c/Registration.h"
20e67cbbefSJacques Pienaar #include "llvm/ADT/ArrayRef.h"
21436c6c9cSStella Laurenzo #include "llvm/ADT/SmallVector.h"
22436c6c9cSStella Laurenzo #include <pybind11/stl.h>
23436c6c9cSStella Laurenzo 
24436c6c9cSStella Laurenzo namespace py = pybind11;
25436c6c9cSStella Laurenzo using namespace mlir;
26436c6c9cSStella Laurenzo using namespace mlir::python;
27436c6c9cSStella Laurenzo 
28436c6c9cSStella Laurenzo using llvm::SmallVector;
29436c6c9cSStella Laurenzo using llvm::StringRef;
30436c6c9cSStella Laurenzo using llvm::Twine;
31436c6c9cSStella Laurenzo 
32436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
33436c6c9cSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
34436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
35436c6c9cSStella Laurenzo 
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// NOTE: spelled "DocString" (not "Docstring") unlike its siblings; the name is
// kept as-is because it may be referenced elsewhere in the file.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

// NOTE(review): the argument names listed here are user-facing keyword names;
// keep them in sync with the py::arg names of the print()/get_asm() bindings
// (Python keywords are case-sensitive, e.g. `use_local_scope`).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
143436c6c9cSStella Laurenzo 
144436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
145436c6c9cSStella Laurenzo // Utilities.
146436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
147436c6c9cSStella Laurenzo 
1484acd8457SAlex Zinenko /// Helper for creating an @classmethod.
149436c6c9cSStella Laurenzo template <class Func, typename... Args>
150436c6c9cSStella Laurenzo py::object classmethod(Func f, Args... args) {
151436c6c9cSStella Laurenzo   py::object cf = py::cpp_function(f, args...);
152436c6c9cSStella Laurenzo   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
153436c6c9cSStella Laurenzo }
154436c6c9cSStella Laurenzo 
155436c6c9cSStella Laurenzo static py::object
156436c6c9cSStella Laurenzo createCustomDialectWrapper(const std::string &dialectNamespace,
157436c6c9cSStella Laurenzo                            py::object dialectDescriptor) {
158436c6c9cSStella Laurenzo   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
159436c6c9cSStella Laurenzo   if (!dialectClass) {
160436c6c9cSStella Laurenzo     // Use the base class.
161436c6c9cSStella Laurenzo     return py::cast(PyDialect(std::move(dialectDescriptor)));
162436c6c9cSStella Laurenzo   }
163436c6c9cSStella Laurenzo 
164436c6c9cSStella Laurenzo   // Create the custom implementation.
165436c6c9cSStella Laurenzo   return (*dialectClass)(std::move(dialectDescriptor));
166436c6c9cSStella Laurenzo }
167436c6c9cSStella Laurenzo 
168436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
169436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
170436c6c9cSStella Laurenzo }
171436c6c9cSStella Laurenzo 
1724acd8457SAlex Zinenko /// Wrapper for the global LLVM debugging flag.
1734acd8457SAlex Zinenko struct PyGlobalDebugFlag {
1744acd8457SAlex Zinenko   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
1754acd8457SAlex Zinenko 
1764acd8457SAlex Zinenko   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
1774acd8457SAlex Zinenko 
1784acd8457SAlex Zinenko   static void bind(py::module &m) {
1794acd8457SAlex Zinenko     // Debug flags.
180f05ff4f7SStella Laurenzo     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
1814acd8457SAlex Zinenko         .def_property_static("flag", &PyGlobalDebugFlag::get,
1824acd8457SAlex Zinenko                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
1834acd8457SAlex Zinenko   }
1844acd8457SAlex Zinenko };
1854acd8457SAlex Zinenko 
186436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
187436c6c9cSStella Laurenzo // Collections.
188436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
189436c6c9cSStella Laurenzo 
190436c6c9cSStella Laurenzo namespace {
191436c6c9cSStella Laurenzo 
192436c6c9cSStella Laurenzo class PyRegionIterator {
193436c6c9cSStella Laurenzo public:
194436c6c9cSStella Laurenzo   PyRegionIterator(PyOperationRef operation)
195436c6c9cSStella Laurenzo       : operation(std::move(operation)) {}
196436c6c9cSStella Laurenzo 
197436c6c9cSStella Laurenzo   PyRegionIterator &dunderIter() { return *this; }
198436c6c9cSStella Laurenzo 
199436c6c9cSStella Laurenzo   PyRegion dunderNext() {
200436c6c9cSStella Laurenzo     operation->checkValid();
201436c6c9cSStella Laurenzo     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
202436c6c9cSStella Laurenzo       throw py::stop_iteration();
203436c6c9cSStella Laurenzo     }
204436c6c9cSStella Laurenzo     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
205436c6c9cSStella Laurenzo     return PyRegion(operation, region);
206436c6c9cSStella Laurenzo   }
207436c6c9cSStella Laurenzo 
208436c6c9cSStella Laurenzo   static void bind(py::module &m) {
209f05ff4f7SStella Laurenzo     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
210436c6c9cSStella Laurenzo         .def("__iter__", &PyRegionIterator::dunderIter)
211436c6c9cSStella Laurenzo         .def("__next__", &PyRegionIterator::dunderNext);
212436c6c9cSStella Laurenzo   }
213436c6c9cSStella Laurenzo 
214436c6c9cSStella Laurenzo private:
215436c6c9cSStella Laurenzo   PyOperationRef operation;
216436c6c9cSStella Laurenzo   int nextIndex = 0;
217436c6c9cSStella Laurenzo };
218436c6c9cSStella Laurenzo 
219436c6c9cSStella Laurenzo /// Regions of an op are fixed length and indexed numerically so are represented
220436c6c9cSStella Laurenzo /// with a sequence-like container.
221436c6c9cSStella Laurenzo class PyRegionList {
222436c6c9cSStella Laurenzo public:
223436c6c9cSStella Laurenzo   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
224436c6c9cSStella Laurenzo 
225436c6c9cSStella Laurenzo   intptr_t dunderLen() {
226436c6c9cSStella Laurenzo     operation->checkValid();
227436c6c9cSStella Laurenzo     return mlirOperationGetNumRegions(operation->get());
228436c6c9cSStella Laurenzo   }
229436c6c9cSStella Laurenzo 
230436c6c9cSStella Laurenzo   PyRegion dunderGetItem(intptr_t index) {
231436c6c9cSStella Laurenzo     // dunderLen checks validity.
232436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
233436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
234436c6c9cSStella Laurenzo                        "attempt to access out of bounds region");
235436c6c9cSStella Laurenzo     }
236436c6c9cSStella Laurenzo     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
237436c6c9cSStella Laurenzo     return PyRegion(operation, region);
238436c6c9cSStella Laurenzo   }
239436c6c9cSStella Laurenzo 
240436c6c9cSStella Laurenzo   static void bind(py::module &m) {
241f05ff4f7SStella Laurenzo     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
242436c6c9cSStella Laurenzo         .def("__len__", &PyRegionList::dunderLen)
243436c6c9cSStella Laurenzo         .def("__getitem__", &PyRegionList::dunderGetItem);
244436c6c9cSStella Laurenzo   }
245436c6c9cSStella Laurenzo 
246436c6c9cSStella Laurenzo private:
247436c6c9cSStella Laurenzo   PyOperationRef operation;
248436c6c9cSStella Laurenzo };
249436c6c9cSStella Laurenzo 
250436c6c9cSStella Laurenzo class PyBlockIterator {
251436c6c9cSStella Laurenzo public:
252436c6c9cSStella Laurenzo   PyBlockIterator(PyOperationRef operation, MlirBlock next)
253436c6c9cSStella Laurenzo       : operation(std::move(operation)), next(next) {}
254436c6c9cSStella Laurenzo 
255436c6c9cSStella Laurenzo   PyBlockIterator &dunderIter() { return *this; }
256436c6c9cSStella Laurenzo 
257436c6c9cSStella Laurenzo   PyBlock dunderNext() {
258436c6c9cSStella Laurenzo     operation->checkValid();
259436c6c9cSStella Laurenzo     if (mlirBlockIsNull(next)) {
260436c6c9cSStella Laurenzo       throw py::stop_iteration();
261436c6c9cSStella Laurenzo     }
262436c6c9cSStella Laurenzo 
263436c6c9cSStella Laurenzo     PyBlock returnBlock(operation, next);
264436c6c9cSStella Laurenzo     next = mlirBlockGetNextInRegion(next);
265436c6c9cSStella Laurenzo     return returnBlock;
266436c6c9cSStella Laurenzo   }
267436c6c9cSStella Laurenzo 
268436c6c9cSStella Laurenzo   static void bind(py::module &m) {
269f05ff4f7SStella Laurenzo     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
270436c6c9cSStella Laurenzo         .def("__iter__", &PyBlockIterator::dunderIter)
271436c6c9cSStella Laurenzo         .def("__next__", &PyBlockIterator::dunderNext);
272436c6c9cSStella Laurenzo   }
273436c6c9cSStella Laurenzo 
274436c6c9cSStella Laurenzo private:
275436c6c9cSStella Laurenzo   PyOperationRef operation;
276436c6c9cSStella Laurenzo   MlirBlock next;
277436c6c9cSStella Laurenzo };
278436c6c9cSStella Laurenzo 
279436c6c9cSStella Laurenzo /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
280436c6c9cSStella Laurenzo /// we present them as a more full-featured list-like container but optimize
281436c6c9cSStella Laurenzo /// it for forward iteration. Blocks are always owned by a region.
282436c6c9cSStella Laurenzo class PyBlockList {
283436c6c9cSStella Laurenzo public:
284436c6c9cSStella Laurenzo   PyBlockList(PyOperationRef operation, MlirRegion region)
285436c6c9cSStella Laurenzo       : operation(std::move(operation)), region(region) {}
286436c6c9cSStella Laurenzo 
287436c6c9cSStella Laurenzo   PyBlockIterator dunderIter() {
288436c6c9cSStella Laurenzo     operation->checkValid();
289436c6c9cSStella Laurenzo     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
290436c6c9cSStella Laurenzo   }
291436c6c9cSStella Laurenzo 
292436c6c9cSStella Laurenzo   intptr_t dunderLen() {
293436c6c9cSStella Laurenzo     operation->checkValid();
294436c6c9cSStella Laurenzo     intptr_t count = 0;
295436c6c9cSStella Laurenzo     MlirBlock block = mlirRegionGetFirstBlock(region);
296436c6c9cSStella Laurenzo     while (!mlirBlockIsNull(block)) {
297436c6c9cSStella Laurenzo       count += 1;
298436c6c9cSStella Laurenzo       block = mlirBlockGetNextInRegion(block);
299436c6c9cSStella Laurenzo     }
300436c6c9cSStella Laurenzo     return count;
301436c6c9cSStella Laurenzo   }
302436c6c9cSStella Laurenzo 
303436c6c9cSStella Laurenzo   PyBlock dunderGetItem(intptr_t index) {
304436c6c9cSStella Laurenzo     operation->checkValid();
305436c6c9cSStella Laurenzo     if (index < 0) {
306436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
307436c6c9cSStella Laurenzo                        "attempt to access out of bounds block");
308436c6c9cSStella Laurenzo     }
309436c6c9cSStella Laurenzo     MlirBlock block = mlirRegionGetFirstBlock(region);
310436c6c9cSStella Laurenzo     while (!mlirBlockIsNull(block)) {
311436c6c9cSStella Laurenzo       if (index == 0) {
312436c6c9cSStella Laurenzo         return PyBlock(operation, block);
313436c6c9cSStella Laurenzo       }
314436c6c9cSStella Laurenzo       block = mlirBlockGetNextInRegion(block);
315436c6c9cSStella Laurenzo       index -= 1;
316436c6c9cSStella Laurenzo     }
317436c6c9cSStella Laurenzo     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
318436c6c9cSStella Laurenzo   }
319436c6c9cSStella Laurenzo 
320436c6c9cSStella Laurenzo   PyBlock appendBlock(py::args pyArgTypes) {
321436c6c9cSStella Laurenzo     operation->checkValid();
322436c6c9cSStella Laurenzo     llvm::SmallVector<MlirType, 4> argTypes;
323436c6c9cSStella Laurenzo     argTypes.reserve(pyArgTypes.size());
324436c6c9cSStella Laurenzo     for (auto &pyArg : pyArgTypes) {
325436c6c9cSStella Laurenzo       argTypes.push_back(pyArg.cast<PyType &>());
326436c6c9cSStella Laurenzo     }
327436c6c9cSStella Laurenzo 
328436c6c9cSStella Laurenzo     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
329436c6c9cSStella Laurenzo     mlirRegionAppendOwnedBlock(region, block);
330436c6c9cSStella Laurenzo     return PyBlock(operation, block);
331436c6c9cSStella Laurenzo   }
332436c6c9cSStella Laurenzo 
333436c6c9cSStella Laurenzo   static void bind(py::module &m) {
334f05ff4f7SStella Laurenzo     py::class_<PyBlockList>(m, "BlockList", py::module_local())
335436c6c9cSStella Laurenzo         .def("__getitem__", &PyBlockList::dunderGetItem)
336436c6c9cSStella Laurenzo         .def("__iter__", &PyBlockList::dunderIter)
337436c6c9cSStella Laurenzo         .def("__len__", &PyBlockList::dunderLen)
338436c6c9cSStella Laurenzo         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
339436c6c9cSStella Laurenzo   }
340436c6c9cSStella Laurenzo 
341436c6c9cSStella Laurenzo private:
342436c6c9cSStella Laurenzo   PyOperationRef operation;
343436c6c9cSStella Laurenzo   MlirRegion region;
344436c6c9cSStella Laurenzo };
345436c6c9cSStella Laurenzo 
346436c6c9cSStella Laurenzo class PyOperationIterator {
347436c6c9cSStella Laurenzo public:
348436c6c9cSStella Laurenzo   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
349436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), next(next) {}
350436c6c9cSStella Laurenzo 
351436c6c9cSStella Laurenzo   PyOperationIterator &dunderIter() { return *this; }
352436c6c9cSStella Laurenzo 
353436c6c9cSStella Laurenzo   py::object dunderNext() {
354436c6c9cSStella Laurenzo     parentOperation->checkValid();
355436c6c9cSStella Laurenzo     if (mlirOperationIsNull(next)) {
356436c6c9cSStella Laurenzo       throw py::stop_iteration();
357436c6c9cSStella Laurenzo     }
358436c6c9cSStella Laurenzo 
359436c6c9cSStella Laurenzo     PyOperationRef returnOperation =
360436c6c9cSStella Laurenzo         PyOperation::forOperation(parentOperation->getContext(), next);
361436c6c9cSStella Laurenzo     next = mlirOperationGetNextInBlock(next);
362436c6c9cSStella Laurenzo     return returnOperation->createOpView();
363436c6c9cSStella Laurenzo   }
364436c6c9cSStella Laurenzo 
365436c6c9cSStella Laurenzo   static void bind(py::module &m) {
366f05ff4f7SStella Laurenzo     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
367436c6c9cSStella Laurenzo         .def("__iter__", &PyOperationIterator::dunderIter)
368436c6c9cSStella Laurenzo         .def("__next__", &PyOperationIterator::dunderNext);
369436c6c9cSStella Laurenzo   }
370436c6c9cSStella Laurenzo 
371436c6c9cSStella Laurenzo private:
372436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
373436c6c9cSStella Laurenzo   MlirOperation next;
374436c6c9cSStella Laurenzo };
375436c6c9cSStella Laurenzo 
376436c6c9cSStella Laurenzo /// Operations are exposed by the C-API as a forward-only linked list. In
377436c6c9cSStella Laurenzo /// Python, we present them as a more full-featured list-like container but
378436c6c9cSStella Laurenzo /// optimize it for forward iteration. Iterable operations are always owned
379436c6c9cSStella Laurenzo /// by a block.
380436c6c9cSStella Laurenzo class PyOperationList {
381436c6c9cSStella Laurenzo public:
382436c6c9cSStella Laurenzo   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
383436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), block(block) {}
384436c6c9cSStella Laurenzo 
385436c6c9cSStella Laurenzo   PyOperationIterator dunderIter() {
386436c6c9cSStella Laurenzo     parentOperation->checkValid();
387436c6c9cSStella Laurenzo     return PyOperationIterator(parentOperation,
388436c6c9cSStella Laurenzo                                mlirBlockGetFirstOperation(block));
389436c6c9cSStella Laurenzo   }
390436c6c9cSStella Laurenzo 
391436c6c9cSStella Laurenzo   intptr_t dunderLen() {
392436c6c9cSStella Laurenzo     parentOperation->checkValid();
393436c6c9cSStella Laurenzo     intptr_t count = 0;
394436c6c9cSStella Laurenzo     MlirOperation childOp = mlirBlockGetFirstOperation(block);
395436c6c9cSStella Laurenzo     while (!mlirOperationIsNull(childOp)) {
396436c6c9cSStella Laurenzo       count += 1;
397436c6c9cSStella Laurenzo       childOp = mlirOperationGetNextInBlock(childOp);
398436c6c9cSStella Laurenzo     }
399436c6c9cSStella Laurenzo     return count;
400436c6c9cSStella Laurenzo   }
401436c6c9cSStella Laurenzo 
402436c6c9cSStella Laurenzo   py::object dunderGetItem(intptr_t index) {
403436c6c9cSStella Laurenzo     parentOperation->checkValid();
404436c6c9cSStella Laurenzo     if (index < 0) {
405436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
406436c6c9cSStella Laurenzo                        "attempt to access out of bounds operation");
407436c6c9cSStella Laurenzo     }
408436c6c9cSStella Laurenzo     MlirOperation childOp = mlirBlockGetFirstOperation(block);
409436c6c9cSStella Laurenzo     while (!mlirOperationIsNull(childOp)) {
410436c6c9cSStella Laurenzo       if (index == 0) {
411436c6c9cSStella Laurenzo         return PyOperation::forOperation(parentOperation->getContext(), childOp)
412436c6c9cSStella Laurenzo             ->createOpView();
413436c6c9cSStella Laurenzo       }
414436c6c9cSStella Laurenzo       childOp = mlirOperationGetNextInBlock(childOp);
415436c6c9cSStella Laurenzo       index -= 1;
416436c6c9cSStella Laurenzo     }
417436c6c9cSStella Laurenzo     throw SetPyError(PyExc_IndexError,
418436c6c9cSStella Laurenzo                      "attempt to access out of bounds operation");
419436c6c9cSStella Laurenzo   }
420436c6c9cSStella Laurenzo 
421436c6c9cSStella Laurenzo   static void bind(py::module &m) {
422f05ff4f7SStella Laurenzo     py::class_<PyOperationList>(m, "OperationList", py::module_local())
423436c6c9cSStella Laurenzo         .def("__getitem__", &PyOperationList::dunderGetItem)
424436c6c9cSStella Laurenzo         .def("__iter__", &PyOperationList::dunderIter)
425436c6c9cSStella Laurenzo         .def("__len__", &PyOperationList::dunderLen);
426436c6c9cSStella Laurenzo   }
427436c6c9cSStella Laurenzo 
428436c6c9cSStella Laurenzo private:
429436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
430436c6c9cSStella Laurenzo   MlirBlock block;
431436c6c9cSStella Laurenzo };
432436c6c9cSStella Laurenzo 
433436c6c9cSStella Laurenzo } // namespace
434436c6c9cSStella Laurenzo 
435436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
436436c6c9cSStella Laurenzo // PyMlirContext
437436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
438436c6c9cSStella Laurenzo 
439436c6c9cSStella Laurenzo PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
440436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
441436c6c9cSStella Laurenzo   auto &liveContexts = getLiveContexts();
442436c6c9cSStella Laurenzo   liveContexts[context.ptr] = this;
443436c6c9cSStella Laurenzo }
444436c6c9cSStella Laurenzo 
445436c6c9cSStella Laurenzo PyMlirContext::~PyMlirContext() {
446436c6c9cSStella Laurenzo   // Note that the only public way to construct an instance is via the
447436c6c9cSStella Laurenzo   // forContext method, which always puts the associated handle into
448436c6c9cSStella Laurenzo   // liveContexts.
449436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
450436c6c9cSStella Laurenzo   getLiveContexts().erase(context.ptr);
451436c6c9cSStella Laurenzo   mlirContextDestroy(context);
452436c6c9cSStella Laurenzo }
453436c6c9cSStella Laurenzo 
454436c6c9cSStella Laurenzo py::object PyMlirContext::getCapsule() {
455436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
456436c6c9cSStella Laurenzo }
457436c6c9cSStella Laurenzo 
458436c6c9cSStella Laurenzo py::object PyMlirContext::createFromCapsule(py::object capsule) {
459436c6c9cSStella Laurenzo   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
460436c6c9cSStella Laurenzo   if (mlirContextIsNull(rawContext))
461436c6c9cSStella Laurenzo     throw py::error_already_set();
462436c6c9cSStella Laurenzo   return forContext(rawContext).releaseObject();
463436c6c9cSStella Laurenzo }
464436c6c9cSStella Laurenzo 
465436c6c9cSStella Laurenzo PyMlirContext *PyMlirContext::createNewContextForInit() {
466436c6c9cSStella Laurenzo   MlirContext context = mlirContextCreate();
467436c6c9cSStella Laurenzo   mlirRegisterAllDialects(context);
468436c6c9cSStella Laurenzo   return new PyMlirContext(context);
469436c6c9cSStella Laurenzo }
470436c6c9cSStella Laurenzo 
471436c6c9cSStella Laurenzo PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
472436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
473436c6c9cSStella Laurenzo   auto &liveContexts = getLiveContexts();
474436c6c9cSStella Laurenzo   auto it = liveContexts.find(context.ptr);
475436c6c9cSStella Laurenzo   if (it == liveContexts.end()) {
476436c6c9cSStella Laurenzo     // Create.
477436c6c9cSStella Laurenzo     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
478436c6c9cSStella Laurenzo     py::object pyRef = py::cast(unownedContextWrapper);
479436c6c9cSStella Laurenzo     assert(pyRef && "cast to py::object failed");
480436c6c9cSStella Laurenzo     liveContexts[context.ptr] = unownedContextWrapper;
481436c6c9cSStella Laurenzo     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
482436c6c9cSStella Laurenzo   }
483436c6c9cSStella Laurenzo   // Use existing.
484436c6c9cSStella Laurenzo   py::object pyRef = py::cast(it->second);
485436c6c9cSStella Laurenzo   return PyMlirContextRef(it->second, std::move(pyRef));
486436c6c9cSStella Laurenzo }
487436c6c9cSStella Laurenzo 
488436c6c9cSStella Laurenzo PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
489436c6c9cSStella Laurenzo   static LiveContextMap liveContexts;
490436c6c9cSStella Laurenzo   return liveContexts;
491436c6c9cSStella Laurenzo }
492436c6c9cSStella Laurenzo 
493436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
494436c6c9cSStella Laurenzo 
495436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
496436c6c9cSStella Laurenzo 
497436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
498436c6c9cSStella Laurenzo 
499436c6c9cSStella Laurenzo pybind11::object PyMlirContext::contextEnter() {
500436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushContext(*this);
501436c6c9cSStella Laurenzo }
502436c6c9cSStella Laurenzo 
503436c6c9cSStella Laurenzo void PyMlirContext::contextExit(pybind11::object excType,
504436c6c9cSStella Laurenzo                                 pybind11::object excVal,
505436c6c9cSStella Laurenzo                                 pybind11::object excTb) {
506436c6c9cSStella Laurenzo   PyThreadContextEntry::popContext(*this);
507436c6c9cSStella Laurenzo }
508436c6c9cSStella Laurenzo 
509436c6c9cSStella Laurenzo PyMlirContext &DefaultingPyMlirContext::resolve() {
510436c6c9cSStella Laurenzo   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
511436c6c9cSStella Laurenzo   if (!context) {
512436c6c9cSStella Laurenzo     throw SetPyError(
513436c6c9cSStella Laurenzo         PyExc_RuntimeError,
514436c6c9cSStella Laurenzo         "An MLIR function requires a Context but none was provided in the call "
515436c6c9cSStella Laurenzo         "or from the surrounding environment. Either pass to the function with "
516436c6c9cSStella Laurenzo         "a 'context=' argument or establish a default using 'with Context():'");
517436c6c9cSStella Laurenzo   }
518436c6c9cSStella Laurenzo   return *context;
519436c6c9cSStella Laurenzo }
520436c6c9cSStella Laurenzo 
521436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
522436c6c9cSStella Laurenzo // PyThreadContextEntry management
523436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
524436c6c9cSStella Laurenzo 
525436c6c9cSStella Laurenzo std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
526436c6c9cSStella Laurenzo   static thread_local std::vector<PyThreadContextEntry> stack;
527436c6c9cSStella Laurenzo   return stack;
528436c6c9cSStella Laurenzo }
529436c6c9cSStella Laurenzo 
530436c6c9cSStella Laurenzo PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
531436c6c9cSStella Laurenzo   auto &stack = getStack();
532436c6c9cSStella Laurenzo   if (stack.empty())
533436c6c9cSStella Laurenzo     return nullptr;
534436c6c9cSStella Laurenzo   return &stack.back();
535436c6c9cSStella Laurenzo }
536436c6c9cSStella Laurenzo 
537436c6c9cSStella Laurenzo void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
538436c6c9cSStella Laurenzo                                 py::object insertionPoint,
539436c6c9cSStella Laurenzo                                 py::object location) {
540436c6c9cSStella Laurenzo   auto &stack = getStack();
541436c6c9cSStella Laurenzo   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
542436c6c9cSStella Laurenzo                      std::move(location));
543436c6c9cSStella Laurenzo   // If the new stack has more than one entry and the context of the new top
544436c6c9cSStella Laurenzo   // entry matches the previous, copy the insertionPoint and location from the
545436c6c9cSStella Laurenzo   // previous entry if missing from the new top entry.
546436c6c9cSStella Laurenzo   if (stack.size() > 1) {
547436c6c9cSStella Laurenzo     auto &prev = *(stack.rbegin() + 1);
548436c6c9cSStella Laurenzo     auto &current = stack.back();
549436c6c9cSStella Laurenzo     if (current.context.is(prev.context)) {
550436c6c9cSStella Laurenzo       // Default non-context objects from the previous entry.
551436c6c9cSStella Laurenzo       if (!current.insertionPoint)
552436c6c9cSStella Laurenzo         current.insertionPoint = prev.insertionPoint;
553436c6c9cSStella Laurenzo       if (!current.location)
554436c6c9cSStella Laurenzo         current.location = prev.location;
555436c6c9cSStella Laurenzo     }
556436c6c9cSStella Laurenzo   }
557436c6c9cSStella Laurenzo }
558436c6c9cSStella Laurenzo 
559436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getContext() {
560436c6c9cSStella Laurenzo   if (!context)
561436c6c9cSStella Laurenzo     return nullptr;
562436c6c9cSStella Laurenzo   return py::cast<PyMlirContext *>(context);
563436c6c9cSStella Laurenzo }
564436c6c9cSStella Laurenzo 
565436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
566436c6c9cSStella Laurenzo   if (!insertionPoint)
567436c6c9cSStella Laurenzo     return nullptr;
568436c6c9cSStella Laurenzo   return py::cast<PyInsertionPoint *>(insertionPoint);
569436c6c9cSStella Laurenzo }
570436c6c9cSStella Laurenzo 
571436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getLocation() {
572436c6c9cSStella Laurenzo   if (!location)
573436c6c9cSStella Laurenzo     return nullptr;
574436c6c9cSStella Laurenzo   return py::cast<PyLocation *>(location);
575436c6c9cSStella Laurenzo }
576436c6c9cSStella Laurenzo 
577436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getDefaultContext() {
578436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
579436c6c9cSStella Laurenzo   return tos ? tos->getContext() : nullptr;
580436c6c9cSStella Laurenzo }
581436c6c9cSStella Laurenzo 
582436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
583436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
584436c6c9cSStella Laurenzo   return tos ? tos->getInsertionPoint() : nullptr;
585436c6c9cSStella Laurenzo }
586436c6c9cSStella Laurenzo 
587436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getDefaultLocation() {
588436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
589436c6c9cSStella Laurenzo   return tos ? tos->getLocation() : nullptr;
590436c6c9cSStella Laurenzo }
591436c6c9cSStella Laurenzo 
592436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
593436c6c9cSStella Laurenzo   py::object contextObj = py::cast(context);
594436c6c9cSStella Laurenzo   push(FrameKind::Context, /*context=*/contextObj,
595436c6c9cSStella Laurenzo        /*insertionPoint=*/py::object(),
596436c6c9cSStella Laurenzo        /*location=*/py::object());
597436c6c9cSStella Laurenzo   return contextObj;
598436c6c9cSStella Laurenzo }
599436c6c9cSStella Laurenzo 
600436c6c9cSStella Laurenzo void PyThreadContextEntry::popContext(PyMlirContext &context) {
601436c6c9cSStella Laurenzo   auto &stack = getStack();
602436c6c9cSStella Laurenzo   if (stack.empty())
603436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
604436c6c9cSStella Laurenzo   auto &tos = stack.back();
605436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
606436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
607436c6c9cSStella Laurenzo   stack.pop_back();
608436c6c9cSStella Laurenzo }
609436c6c9cSStella Laurenzo 
610436c6c9cSStella Laurenzo py::object
611436c6c9cSStella Laurenzo PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
612436c6c9cSStella Laurenzo   py::object contextObj =
613436c6c9cSStella Laurenzo       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
614436c6c9cSStella Laurenzo   py::object insertionPointObj = py::cast(insertionPoint);
615436c6c9cSStella Laurenzo   push(FrameKind::InsertionPoint,
616436c6c9cSStella Laurenzo        /*context=*/contextObj,
617436c6c9cSStella Laurenzo        /*insertionPoint=*/insertionPointObj,
618436c6c9cSStella Laurenzo        /*location=*/py::object());
619436c6c9cSStella Laurenzo   return insertionPointObj;
620436c6c9cSStella Laurenzo }
621436c6c9cSStella Laurenzo 
622436c6c9cSStella Laurenzo void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
623436c6c9cSStella Laurenzo   auto &stack = getStack();
624436c6c9cSStella Laurenzo   if (stack.empty())
625436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError,
626436c6c9cSStella Laurenzo                      "Unbalanced InsertionPoint enter/exit");
627436c6c9cSStella Laurenzo   auto &tos = stack.back();
628436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::InsertionPoint &&
629436c6c9cSStella Laurenzo       tos.getInsertionPoint() != &insertionPoint)
630436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError,
631436c6c9cSStella Laurenzo                      "Unbalanced InsertionPoint enter/exit");
632436c6c9cSStella Laurenzo   stack.pop_back();
633436c6c9cSStella Laurenzo }
634436c6c9cSStella Laurenzo 
635436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
636436c6c9cSStella Laurenzo   py::object contextObj = location.getContext().getObject();
637436c6c9cSStella Laurenzo   py::object locationObj = py::cast(location);
638436c6c9cSStella Laurenzo   push(FrameKind::Location, /*context=*/contextObj,
639436c6c9cSStella Laurenzo        /*insertionPoint=*/py::object(),
640436c6c9cSStella Laurenzo        /*location=*/locationObj);
641436c6c9cSStella Laurenzo   return locationObj;
642436c6c9cSStella Laurenzo }
643436c6c9cSStella Laurenzo 
644436c6c9cSStella Laurenzo void PyThreadContextEntry::popLocation(PyLocation &location) {
645436c6c9cSStella Laurenzo   auto &stack = getStack();
646436c6c9cSStella Laurenzo   if (stack.empty())
647436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
648436c6c9cSStella Laurenzo   auto &tos = stack.back();
649436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
650436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
651436c6c9cSStella Laurenzo   stack.pop_back();
652436c6c9cSStella Laurenzo }
653436c6c9cSStella Laurenzo 
654436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
655436c6c9cSStella Laurenzo // PyDialect, PyDialectDescriptor, PyDialects
656436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
657436c6c9cSStella Laurenzo 
658436c6c9cSStella Laurenzo MlirDialect PyDialects::getDialectForKey(const std::string &key,
659436c6c9cSStella Laurenzo                                          bool attrError) {
660f8479d9dSRiver Riddle   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
661f8479d9dSRiver Riddle                                                     {key.data(), key.size()});
662436c6c9cSStella Laurenzo   if (mlirDialectIsNull(dialect)) {
663436c6c9cSStella Laurenzo     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
664436c6c9cSStella Laurenzo                      Twine("Dialect '") + key + "' not found");
665436c6c9cSStella Laurenzo   }
666436c6c9cSStella Laurenzo   return dialect;
667436c6c9cSStella Laurenzo }
668436c6c9cSStella Laurenzo 
669436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
670436c6c9cSStella Laurenzo // PyLocation
671436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
672436c6c9cSStella Laurenzo 
673436c6c9cSStella Laurenzo py::object PyLocation::getCapsule() {
674436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
675436c6c9cSStella Laurenzo }
676436c6c9cSStella Laurenzo 
677436c6c9cSStella Laurenzo PyLocation PyLocation::createFromCapsule(py::object capsule) {
678436c6c9cSStella Laurenzo   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
679436c6c9cSStella Laurenzo   if (mlirLocationIsNull(rawLoc))
680436c6c9cSStella Laurenzo     throw py::error_already_set();
681436c6c9cSStella Laurenzo   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
682436c6c9cSStella Laurenzo                     rawLoc);
683436c6c9cSStella Laurenzo }
684436c6c9cSStella Laurenzo 
685436c6c9cSStella Laurenzo py::object PyLocation::contextEnter() {
686436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushLocation(*this);
687436c6c9cSStella Laurenzo }
688436c6c9cSStella Laurenzo 
689436c6c9cSStella Laurenzo void PyLocation::contextExit(py::object excType, py::object excVal,
690436c6c9cSStella Laurenzo                              py::object excTb) {
691436c6c9cSStella Laurenzo   PyThreadContextEntry::popLocation(*this);
692436c6c9cSStella Laurenzo }
693436c6c9cSStella Laurenzo 
694436c6c9cSStella Laurenzo PyLocation &DefaultingPyLocation::resolve() {
695436c6c9cSStella Laurenzo   auto *location = PyThreadContextEntry::getDefaultLocation();
696436c6c9cSStella Laurenzo   if (!location) {
697436c6c9cSStella Laurenzo     throw SetPyError(
698436c6c9cSStella Laurenzo         PyExc_RuntimeError,
699436c6c9cSStella Laurenzo         "An MLIR function requires a Location but none was provided in the "
700436c6c9cSStella Laurenzo         "call or from the surrounding environment. Either pass to the function "
701436c6c9cSStella Laurenzo         "with a 'loc=' argument or establish a default using 'with loc:'");
702436c6c9cSStella Laurenzo   }
703436c6c9cSStella Laurenzo   return *location;
704436c6c9cSStella Laurenzo }
705436c6c9cSStella Laurenzo 
706436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
707436c6c9cSStella Laurenzo // PyModule
708436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
709436c6c9cSStella Laurenzo 
710436c6c9cSStella Laurenzo PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
711436c6c9cSStella Laurenzo     : BaseContextObject(std::move(contextRef)), module(module) {}
712436c6c9cSStella Laurenzo 
713436c6c9cSStella Laurenzo PyModule::~PyModule() {
714436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
715436c6c9cSStella Laurenzo   auto &liveModules = getContext()->liveModules;
716436c6c9cSStella Laurenzo   assert(liveModules.count(module.ptr) == 1 &&
717436c6c9cSStella Laurenzo          "destroying module not in live map");
718436c6c9cSStella Laurenzo   liveModules.erase(module.ptr);
719436c6c9cSStella Laurenzo   mlirModuleDestroy(module);
720436c6c9cSStella Laurenzo }
721436c6c9cSStella Laurenzo 
722436c6c9cSStella Laurenzo PyModuleRef PyModule::forModule(MlirModule module) {
723436c6c9cSStella Laurenzo   MlirContext context = mlirModuleGetContext(module);
724436c6c9cSStella Laurenzo   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
725436c6c9cSStella Laurenzo 
726436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
727436c6c9cSStella Laurenzo   auto &liveModules = contextRef->liveModules;
728436c6c9cSStella Laurenzo   auto it = liveModules.find(module.ptr);
729436c6c9cSStella Laurenzo   if (it == liveModules.end()) {
730436c6c9cSStella Laurenzo     // Create.
731436c6c9cSStella Laurenzo     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
732436c6c9cSStella Laurenzo     // Note that the default return value policy on cast is automatic_reference,
733436c6c9cSStella Laurenzo     // which does not take ownership (delete will not be called).
734436c6c9cSStella Laurenzo     // Just be explicit.
735436c6c9cSStella Laurenzo     py::object pyRef =
736436c6c9cSStella Laurenzo         py::cast(unownedModule, py::return_value_policy::take_ownership);
737436c6c9cSStella Laurenzo     unownedModule->handle = pyRef;
738436c6c9cSStella Laurenzo     liveModules[module.ptr] =
739436c6c9cSStella Laurenzo         std::make_pair(unownedModule->handle, unownedModule);
740436c6c9cSStella Laurenzo     return PyModuleRef(unownedModule, std::move(pyRef));
741436c6c9cSStella Laurenzo   }
742436c6c9cSStella Laurenzo   // Use existing.
743436c6c9cSStella Laurenzo   PyModule *existing = it->second.second;
744436c6c9cSStella Laurenzo   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
745436c6c9cSStella Laurenzo   return PyModuleRef(existing, std::move(pyRef));
746436c6c9cSStella Laurenzo }
747436c6c9cSStella Laurenzo 
748436c6c9cSStella Laurenzo py::object PyModule::createFromCapsule(py::object capsule) {
749436c6c9cSStella Laurenzo   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
750436c6c9cSStella Laurenzo   if (mlirModuleIsNull(rawModule))
751436c6c9cSStella Laurenzo     throw py::error_already_set();
752436c6c9cSStella Laurenzo   return forModule(rawModule).releaseObject();
753436c6c9cSStella Laurenzo }
754436c6c9cSStella Laurenzo 
755436c6c9cSStella Laurenzo py::object PyModule::getCapsule() {
756436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
757436c6c9cSStella Laurenzo }
758436c6c9cSStella Laurenzo 
759436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
760436c6c9cSStella Laurenzo // PyOperation
761436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
762436c6c9cSStella Laurenzo 
763436c6c9cSStella Laurenzo PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
764436c6c9cSStella Laurenzo     : BaseContextObject(std::move(contextRef)), operation(operation) {}
765436c6c9cSStella Laurenzo 
766436c6c9cSStella Laurenzo PyOperation::~PyOperation() {
76749745f87SMike Urbach   // If the operation has already been invalidated there is nothing to do.
76849745f87SMike Urbach   if (!valid)
76949745f87SMike Urbach     return;
770436c6c9cSStella Laurenzo   auto &liveOperations = getContext()->liveOperations;
771436c6c9cSStella Laurenzo   assert(liveOperations.count(operation.ptr) == 1 &&
772436c6c9cSStella Laurenzo          "destroying operation not in live map");
773436c6c9cSStella Laurenzo   liveOperations.erase(operation.ptr);
774436c6c9cSStella Laurenzo   if (!isAttached()) {
775436c6c9cSStella Laurenzo     mlirOperationDestroy(operation);
776436c6c9cSStella Laurenzo   }
777436c6c9cSStella Laurenzo }
778436c6c9cSStella Laurenzo 
779436c6c9cSStella Laurenzo PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
780436c6c9cSStella Laurenzo                                            MlirOperation operation,
781436c6c9cSStella Laurenzo                                            py::object parentKeepAlive) {
782436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
783436c6c9cSStella Laurenzo   // Create.
784436c6c9cSStella Laurenzo   PyOperation *unownedOperation =
785436c6c9cSStella Laurenzo       new PyOperation(std::move(contextRef), operation);
786436c6c9cSStella Laurenzo   // Note that the default return value policy on cast is automatic_reference,
787436c6c9cSStella Laurenzo   // which does not take ownership (delete will not be called).
788436c6c9cSStella Laurenzo   // Just be explicit.
789436c6c9cSStella Laurenzo   py::object pyRef =
790436c6c9cSStella Laurenzo       py::cast(unownedOperation, py::return_value_policy::take_ownership);
791436c6c9cSStella Laurenzo   unownedOperation->handle = pyRef;
792436c6c9cSStella Laurenzo   if (parentKeepAlive) {
793436c6c9cSStella Laurenzo     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
794436c6c9cSStella Laurenzo   }
795436c6c9cSStella Laurenzo   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
796436c6c9cSStella Laurenzo   return PyOperationRef(unownedOperation, std::move(pyRef));
797436c6c9cSStella Laurenzo }
798436c6c9cSStella Laurenzo 
799436c6c9cSStella Laurenzo PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
800436c6c9cSStella Laurenzo                                          MlirOperation operation,
801436c6c9cSStella Laurenzo                                          py::object parentKeepAlive) {
802436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
803436c6c9cSStella Laurenzo   auto it = liveOperations.find(operation.ptr);
804436c6c9cSStella Laurenzo   if (it == liveOperations.end()) {
805436c6c9cSStella Laurenzo     // Create.
806436c6c9cSStella Laurenzo     return createInstance(std::move(contextRef), operation,
807436c6c9cSStella Laurenzo                           std::move(parentKeepAlive));
808436c6c9cSStella Laurenzo   }
809436c6c9cSStella Laurenzo   // Use existing.
810436c6c9cSStella Laurenzo   PyOperation *existing = it->second.second;
811436c6c9cSStella Laurenzo   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
812436c6c9cSStella Laurenzo   return PyOperationRef(existing, std::move(pyRef));
813436c6c9cSStella Laurenzo }
814436c6c9cSStella Laurenzo 
815436c6c9cSStella Laurenzo PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
816436c6c9cSStella Laurenzo                                            MlirOperation operation,
817436c6c9cSStella Laurenzo                                            py::object parentKeepAlive) {
818436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
819436c6c9cSStella Laurenzo   assert(liveOperations.count(operation.ptr) == 0 &&
820436c6c9cSStella Laurenzo          "cannot create detached operation that already exists");
821436c6c9cSStella Laurenzo   (void)liveOperations;
822436c6c9cSStella Laurenzo 
823436c6c9cSStella Laurenzo   PyOperationRef created = createInstance(std::move(contextRef), operation,
824436c6c9cSStella Laurenzo                                           std::move(parentKeepAlive));
825436c6c9cSStella Laurenzo   created->attached = false;
826436c6c9cSStella Laurenzo   return created;
827436c6c9cSStella Laurenzo }
828436c6c9cSStella Laurenzo 
829436c6c9cSStella Laurenzo void PyOperation::checkValid() const {
830436c6c9cSStella Laurenzo   if (!valid) {
831436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
832436c6c9cSStella Laurenzo   }
833436c6c9cSStella Laurenzo }
834436c6c9cSStella Laurenzo 
835436c6c9cSStella Laurenzo void PyOperationBase::print(py::object fileObject, bool binary,
836436c6c9cSStella Laurenzo                             llvm::Optional<int64_t> largeElementsLimit,
837436c6c9cSStella Laurenzo                             bool enableDebugInfo, bool prettyDebugInfo,
838ace1d0adSStella Laurenzo                             bool printGenericOpForm, bool useLocalScope,
839ace1d0adSStella Laurenzo                             bool assumeVerified) {
840436c6c9cSStella Laurenzo   PyOperation &operation = getOperation();
841436c6c9cSStella Laurenzo   operation.checkValid();
842436c6c9cSStella Laurenzo   if (fileObject.is_none())
843436c6c9cSStella Laurenzo     fileObject = py::module::import("sys").attr("stdout");
844436c6c9cSStella Laurenzo 
845ace1d0adSStella Laurenzo   if (!assumeVerified && !printGenericOpForm &&
846ace1d0adSStella Laurenzo       !mlirOperationVerify(operation)) {
847ace1d0adSStella Laurenzo     std::string message("// Verification failed, printing generic form\n");
848ace1d0adSStella Laurenzo     if (binary) {
849ace1d0adSStella Laurenzo       fileObject.attr("write")(py::bytes(message));
850ace1d0adSStella Laurenzo     } else {
851ace1d0adSStella Laurenzo       fileObject.attr("write")(py::str(message));
852ace1d0adSStella Laurenzo     }
853436c6c9cSStella Laurenzo     printGenericOpForm = true;
854436c6c9cSStella Laurenzo   }
855436c6c9cSStella Laurenzo 
856436c6c9cSStella Laurenzo   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
857436c6c9cSStella Laurenzo   if (largeElementsLimit)
858436c6c9cSStella Laurenzo     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
859436c6c9cSStella Laurenzo   if (enableDebugInfo)
860436c6c9cSStella Laurenzo     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
861436c6c9cSStella Laurenzo   if (printGenericOpForm)
862436c6c9cSStella Laurenzo     mlirOpPrintingFlagsPrintGenericOpForm(flags);
863436c6c9cSStella Laurenzo 
864436c6c9cSStella Laurenzo   PyFileAccumulator accum(fileObject, binary);
865436c6c9cSStella Laurenzo   py::gil_scoped_release();
866436c6c9cSStella Laurenzo   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
867436c6c9cSStella Laurenzo                               accum.getUserData());
868436c6c9cSStella Laurenzo   mlirOpPrintingFlagsDestroy(flags);
869436c6c9cSStella Laurenzo }
870436c6c9cSStella Laurenzo 
871436c6c9cSStella Laurenzo py::object PyOperationBase::getAsm(bool binary,
872436c6c9cSStella Laurenzo                                    llvm::Optional<int64_t> largeElementsLimit,
873436c6c9cSStella Laurenzo                                    bool enableDebugInfo, bool prettyDebugInfo,
874ace1d0adSStella Laurenzo                                    bool printGenericOpForm, bool useLocalScope,
875ace1d0adSStella Laurenzo                                    bool assumeVerified) {
876436c6c9cSStella Laurenzo   py::object fileObject;
877436c6c9cSStella Laurenzo   if (binary) {
878436c6c9cSStella Laurenzo     fileObject = py::module::import("io").attr("BytesIO")();
879436c6c9cSStella Laurenzo   } else {
880436c6c9cSStella Laurenzo     fileObject = py::module::import("io").attr("StringIO")();
881436c6c9cSStella Laurenzo   }
882436c6c9cSStella Laurenzo   print(fileObject, /*binary=*/binary,
883436c6c9cSStella Laurenzo         /*largeElementsLimit=*/largeElementsLimit,
884436c6c9cSStella Laurenzo         /*enableDebugInfo=*/enableDebugInfo,
885436c6c9cSStella Laurenzo         /*prettyDebugInfo=*/prettyDebugInfo,
886436c6c9cSStella Laurenzo         /*printGenericOpForm=*/printGenericOpForm,
887ace1d0adSStella Laurenzo         /*useLocalScope=*/useLocalScope,
888ace1d0adSStella Laurenzo         /*assumeVerified=*/assumeVerified);
889436c6c9cSStella Laurenzo 
890436c6c9cSStella Laurenzo   return fileObject.attr("getvalue")();
891436c6c9cSStella Laurenzo }
892436c6c9cSStella Laurenzo 
89324685aaeSAlex Zinenko void PyOperationBase::moveAfter(PyOperationBase &other) {
89424685aaeSAlex Zinenko   PyOperation &operation = getOperation();
89524685aaeSAlex Zinenko   PyOperation &otherOp = other.getOperation();
89624685aaeSAlex Zinenko   operation.checkValid();
89724685aaeSAlex Zinenko   otherOp.checkValid();
89824685aaeSAlex Zinenko   mlirOperationMoveAfter(operation, otherOp);
89924685aaeSAlex Zinenko   operation.parentKeepAlive = otherOp.parentKeepAlive;
90024685aaeSAlex Zinenko }
90124685aaeSAlex Zinenko 
90224685aaeSAlex Zinenko void PyOperationBase::moveBefore(PyOperationBase &other) {
90324685aaeSAlex Zinenko   PyOperation &operation = getOperation();
90424685aaeSAlex Zinenko   PyOperation &otherOp = other.getOperation();
90524685aaeSAlex Zinenko   operation.checkValid();
90624685aaeSAlex Zinenko   otherOp.checkValid();
90724685aaeSAlex Zinenko   mlirOperationMoveBefore(operation, otherOp);
90824685aaeSAlex Zinenko   operation.parentKeepAlive = otherOp.parentKeepAlive;
90924685aaeSAlex Zinenko }
91024685aaeSAlex Zinenko 
9111689dadeSJohn Demme llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
91249745f87SMike Urbach   checkValid();
913436c6c9cSStella Laurenzo   if (!isAttached())
914436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
915436c6c9cSStella Laurenzo   MlirOperation operation = mlirOperationGetParentOperation(get());
916436c6c9cSStella Laurenzo   if (mlirOperationIsNull(operation))
9171689dadeSJohn Demme     return {};
918436c6c9cSStella Laurenzo   return PyOperation::forOperation(getContext(), operation);
919436c6c9cSStella Laurenzo }
920436c6c9cSStella Laurenzo 
921436c6c9cSStella Laurenzo PyBlock PyOperation::getBlock() {
92249745f87SMike Urbach   checkValid();
9231689dadeSJohn Demme   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
924436c6c9cSStella Laurenzo   MlirBlock block = mlirOperationGetBlock(get());
925436c6c9cSStella Laurenzo   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
9261689dadeSJohn Demme   assert(parentOperation && "Operation has no parent");
9271689dadeSJohn Demme   return PyBlock{std::move(*parentOperation), block};
928436c6c9cSStella Laurenzo }
929436c6c9cSStella Laurenzo 
9300126e906SJohn Demme py::object PyOperation::getCapsule() {
93149745f87SMike Urbach   checkValid();
9320126e906SJohn Demme   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
9330126e906SJohn Demme }
9340126e906SJohn Demme 
9350126e906SJohn Demme py::object PyOperation::createFromCapsule(py::object capsule) {
9360126e906SJohn Demme   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
9370126e906SJohn Demme   if (mlirOperationIsNull(rawOperation))
9380126e906SJohn Demme     throw py::error_already_set();
9390126e906SJohn Demme   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
9400126e906SJohn Demme   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
9410126e906SJohn Demme       .releaseObject();
9420126e906SJohn Demme }
9430126e906SJohn Demme 
944436c6c9cSStella Laurenzo py::object PyOperation::create(
945436c6c9cSStella Laurenzo     std::string name, llvm::Optional<std::vector<PyType *>> results,
946436c6c9cSStella Laurenzo     llvm::Optional<std::vector<PyValue *>> operands,
947436c6c9cSStella Laurenzo     llvm::Optional<py::dict> attributes,
948436c6c9cSStella Laurenzo     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
949436c6c9cSStella Laurenzo     DefaultingPyLocation location, py::object maybeIp) {
950436c6c9cSStella Laurenzo   llvm::SmallVector<MlirValue, 4> mlirOperands;
951436c6c9cSStella Laurenzo   llvm::SmallVector<MlirType, 4> mlirResults;
952436c6c9cSStella Laurenzo   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
953436c6c9cSStella Laurenzo   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
954436c6c9cSStella Laurenzo 
955436c6c9cSStella Laurenzo   // General parameter validation.
956436c6c9cSStella Laurenzo   if (regions < 0)
957436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
958436c6c9cSStella Laurenzo 
959436c6c9cSStella Laurenzo   // Unpack/validate operands.
960436c6c9cSStella Laurenzo   if (operands) {
961436c6c9cSStella Laurenzo     mlirOperands.reserve(operands->size());
962436c6c9cSStella Laurenzo     for (PyValue *operand : *operands) {
963436c6c9cSStella Laurenzo       if (!operand)
964436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
965436c6c9cSStella Laurenzo       mlirOperands.push_back(operand->get());
966436c6c9cSStella Laurenzo     }
967436c6c9cSStella Laurenzo   }
968436c6c9cSStella Laurenzo 
969436c6c9cSStella Laurenzo   // Unpack/validate results.
970436c6c9cSStella Laurenzo   if (results) {
971436c6c9cSStella Laurenzo     mlirResults.reserve(results->size());
972436c6c9cSStella Laurenzo     for (PyType *result : *results) {
973436c6c9cSStella Laurenzo       // TODO: Verify result type originate from the same context.
974436c6c9cSStella Laurenzo       if (!result)
975436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "result type cannot be None");
976436c6c9cSStella Laurenzo       mlirResults.push_back(*result);
977436c6c9cSStella Laurenzo     }
978436c6c9cSStella Laurenzo   }
979436c6c9cSStella Laurenzo   // Unpack/validate attributes.
980436c6c9cSStella Laurenzo   if (attributes) {
981436c6c9cSStella Laurenzo     mlirAttributes.reserve(attributes->size());
982436c6c9cSStella Laurenzo     for (auto &it : *attributes) {
983436c6c9cSStella Laurenzo       std::string key;
984436c6c9cSStella Laurenzo       try {
985436c6c9cSStella Laurenzo         key = it.first.cast<std::string>();
986436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
987436c6c9cSStella Laurenzo         std::string msg = "Invalid attribute key (not a string) when "
988436c6c9cSStella Laurenzo                           "attempting to create the operation \"" +
989436c6c9cSStella Laurenzo                           name + "\" (" + err.what() + ")";
990436c6c9cSStella Laurenzo         throw py::cast_error(msg);
991436c6c9cSStella Laurenzo       }
992436c6c9cSStella Laurenzo       try {
993436c6c9cSStella Laurenzo         auto &attribute = it.second.cast<PyAttribute &>();
994436c6c9cSStella Laurenzo         // TODO: Verify attribute originates from the same context.
995436c6c9cSStella Laurenzo         mlirAttributes.emplace_back(std::move(key), attribute);
996436c6c9cSStella Laurenzo       } catch (py::reference_cast_error &) {
997436c6c9cSStella Laurenzo         // This exception seems thrown when the value is "None".
998436c6c9cSStella Laurenzo         std::string msg =
999436c6c9cSStella Laurenzo             "Found an invalid (`None`?) attribute value for the key \"" + key +
1000436c6c9cSStella Laurenzo             "\" when attempting to create the operation \"" + name + "\"";
1001436c6c9cSStella Laurenzo         throw py::cast_error(msg);
1002436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
1003436c6c9cSStella Laurenzo         std::string msg = "Invalid attribute value for the key \"" + key +
1004436c6c9cSStella Laurenzo                           "\" when attempting to create the operation \"" +
1005436c6c9cSStella Laurenzo                           name + "\" (" + err.what() + ")";
1006436c6c9cSStella Laurenzo         throw py::cast_error(msg);
1007436c6c9cSStella Laurenzo       }
1008436c6c9cSStella Laurenzo     }
1009436c6c9cSStella Laurenzo   }
1010436c6c9cSStella Laurenzo   // Unpack/validate successors.
1011436c6c9cSStella Laurenzo   if (successors) {
1012436c6c9cSStella Laurenzo     mlirSuccessors.reserve(successors->size());
1013436c6c9cSStella Laurenzo     for (auto *successor : *successors) {
1014436c6c9cSStella Laurenzo       // TODO: Verify successor originate from the same context.
1015436c6c9cSStella Laurenzo       if (!successor)
1016436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
1017436c6c9cSStella Laurenzo       mlirSuccessors.push_back(successor->get());
1018436c6c9cSStella Laurenzo     }
1019436c6c9cSStella Laurenzo   }
1020436c6c9cSStella Laurenzo 
1021436c6c9cSStella Laurenzo   // Apply unpacked/validated to the operation state. Beyond this
1022436c6c9cSStella Laurenzo   // point, exceptions cannot be thrown or else the state will leak.
1023436c6c9cSStella Laurenzo   MlirOperationState state =
1024436c6c9cSStella Laurenzo       mlirOperationStateGet(toMlirStringRef(name), location);
1025436c6c9cSStella Laurenzo   if (!mlirOperands.empty())
1026436c6c9cSStella Laurenzo     mlirOperationStateAddOperands(&state, mlirOperands.size(),
1027436c6c9cSStella Laurenzo                                   mlirOperands.data());
1028436c6c9cSStella Laurenzo   if (!mlirResults.empty())
1029436c6c9cSStella Laurenzo     mlirOperationStateAddResults(&state, mlirResults.size(),
1030436c6c9cSStella Laurenzo                                  mlirResults.data());
1031436c6c9cSStella Laurenzo   if (!mlirAttributes.empty()) {
1032436c6c9cSStella Laurenzo     // Note that the attribute names directly reference bytes in
1033436c6c9cSStella Laurenzo     // mlirAttributes, so that vector must not be changed from here
1034436c6c9cSStella Laurenzo     // on.
1035436c6c9cSStella Laurenzo     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1036436c6c9cSStella Laurenzo     mlirNamedAttributes.reserve(mlirAttributes.size());
1037436c6c9cSStella Laurenzo     for (auto &it : mlirAttributes)
1038436c6c9cSStella Laurenzo       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1039436c6c9cSStella Laurenzo           mlirIdentifierGet(mlirAttributeGetContext(it.second),
1040436c6c9cSStella Laurenzo                             toMlirStringRef(it.first)),
1041436c6c9cSStella Laurenzo           it.second));
1042436c6c9cSStella Laurenzo     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1043436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1044436c6c9cSStella Laurenzo   }
1045436c6c9cSStella Laurenzo   if (!mlirSuccessors.empty())
1046436c6c9cSStella Laurenzo     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1047436c6c9cSStella Laurenzo                                     mlirSuccessors.data());
1048436c6c9cSStella Laurenzo   if (regions) {
1049436c6c9cSStella Laurenzo     llvm::SmallVector<MlirRegion, 4> mlirRegions;
1050436c6c9cSStella Laurenzo     mlirRegions.resize(regions);
1051436c6c9cSStella Laurenzo     for (int i = 0; i < regions; ++i)
1052436c6c9cSStella Laurenzo       mlirRegions[i] = mlirRegionCreate();
1053436c6c9cSStella Laurenzo     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1054436c6c9cSStella Laurenzo                                       mlirRegions.data());
1055436c6c9cSStella Laurenzo   }
1056436c6c9cSStella Laurenzo 
1057436c6c9cSStella Laurenzo   // Construct the operation.
1058436c6c9cSStella Laurenzo   MlirOperation operation = mlirOperationCreate(&state);
1059436c6c9cSStella Laurenzo   PyOperationRef created =
1060436c6c9cSStella Laurenzo       PyOperation::createDetached(location->getContext(), operation);
1061436c6c9cSStella Laurenzo 
1062436c6c9cSStella Laurenzo   // InsertPoint active?
1063436c6c9cSStella Laurenzo   if (!maybeIp.is(py::cast(false))) {
1064436c6c9cSStella Laurenzo     PyInsertionPoint *ip;
1065436c6c9cSStella Laurenzo     if (maybeIp.is_none()) {
1066436c6c9cSStella Laurenzo       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1067436c6c9cSStella Laurenzo     } else {
1068436c6c9cSStella Laurenzo       ip = py::cast<PyInsertionPoint *>(maybeIp);
1069436c6c9cSStella Laurenzo     }
1070436c6c9cSStella Laurenzo     if (ip)
1071436c6c9cSStella Laurenzo       ip->insert(*created.get());
1072436c6c9cSStella Laurenzo   }
1073436c6c9cSStella Laurenzo 
1074436c6c9cSStella Laurenzo   return created->createOpView();
1075436c6c9cSStella Laurenzo }
1076436c6c9cSStella Laurenzo 
1077436c6c9cSStella Laurenzo py::object PyOperation::createOpView() {
107849745f87SMike Urbach   checkValid();
1079436c6c9cSStella Laurenzo   MlirIdentifier ident = mlirOperationGetName(get());
1080436c6c9cSStella Laurenzo   MlirStringRef identStr = mlirIdentifierStr(ident);
1081436c6c9cSStella Laurenzo   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1082436c6c9cSStella Laurenzo       StringRef(identStr.data, identStr.length));
1083436c6c9cSStella Laurenzo   if (opViewClass)
1084436c6c9cSStella Laurenzo     return (*opViewClass)(getRef().getObject());
1085436c6c9cSStella Laurenzo   return py::cast(PyOpView(getRef().getObject()));
1086436c6c9cSStella Laurenzo }
1087436c6c9cSStella Laurenzo 
108849745f87SMike Urbach void PyOperation::erase() {
108949745f87SMike Urbach   checkValid();
109049745f87SMike Urbach   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
109149745f87SMike Urbach   // Python reference to a child operation is live. All children should also
109249745f87SMike Urbach   // have their `valid` bit set to false.
109349745f87SMike Urbach   auto &liveOperations = getContext()->liveOperations;
109449745f87SMike Urbach   if (liveOperations.count(operation.ptr))
109549745f87SMike Urbach     liveOperations.erase(operation.ptr);
109649745f87SMike Urbach   mlirOperationDestroy(operation);
109749745f87SMike Urbach   valid = false;
109849745f87SMike Urbach }
109949745f87SMike Urbach 
1100436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1101436c6c9cSStella Laurenzo // PyOpView
1102436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1103436c6c9cSStella Laurenzo 
1104436c6c9cSStella Laurenzo py::object
1105436c6c9cSStella Laurenzo PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1106436c6c9cSStella Laurenzo                        py::list operandList,
1107436c6c9cSStella Laurenzo                        llvm::Optional<py::dict> attributes,
1108436c6c9cSStella Laurenzo                        llvm::Optional<std::vector<PyBlock *>> successors,
1109436c6c9cSStella Laurenzo                        llvm::Optional<int> regions,
1110436c6c9cSStella Laurenzo                        DefaultingPyLocation location, py::object maybeIp) {
1111436c6c9cSStella Laurenzo   PyMlirContextRef context = location->getContext();
1112436c6c9cSStella Laurenzo   // Class level operation construction metadata.
1113436c6c9cSStella Laurenzo   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1114436c6c9cSStella Laurenzo   // Operand and result segment specs are either none, which does no
1115436c6c9cSStella Laurenzo   // variadic unpacking, or a list of ints with segment sizes, where each
1116436c6c9cSStella Laurenzo   // element is either a positive number (typically 1 for a scalar) or -1 to
1117436c6c9cSStella Laurenzo   // indicate that it is derived from the length of the same-indexed operand
1118436c6c9cSStella Laurenzo   // or result (implying that it is a list at that position).
1119436c6c9cSStella Laurenzo   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1120436c6c9cSStella Laurenzo   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1121436c6c9cSStella Laurenzo 
11228d05a288SStella Laurenzo   std::vector<uint32_t> operandSegmentLengths;
11238d05a288SStella Laurenzo   std::vector<uint32_t> resultSegmentLengths;
1124436c6c9cSStella Laurenzo 
1125436c6c9cSStella Laurenzo   // Validate/determine region count.
1126436c6c9cSStella Laurenzo   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1127436c6c9cSStella Laurenzo   int opMinRegionCount = std::get<0>(opRegionSpec);
1128436c6c9cSStella Laurenzo   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1129436c6c9cSStella Laurenzo   if (!regions) {
1130436c6c9cSStella Laurenzo     regions = opMinRegionCount;
1131436c6c9cSStella Laurenzo   }
1132436c6c9cSStella Laurenzo   if (*regions < opMinRegionCount) {
1133436c6c9cSStella Laurenzo     throw py::value_error(
1134436c6c9cSStella Laurenzo         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1135436c6c9cSStella Laurenzo          llvm::Twine(opMinRegionCount) +
1136436c6c9cSStella Laurenzo          " regions but was built with regions=" + llvm::Twine(*regions))
1137436c6c9cSStella Laurenzo             .str());
1138436c6c9cSStella Laurenzo   }
1139436c6c9cSStella Laurenzo   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1140436c6c9cSStella Laurenzo     throw py::value_error(
1141436c6c9cSStella Laurenzo         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1142436c6c9cSStella Laurenzo          llvm::Twine(opMinRegionCount) +
1143436c6c9cSStella Laurenzo          " regions but was built with regions=" + llvm::Twine(*regions))
1144436c6c9cSStella Laurenzo             .str());
1145436c6c9cSStella Laurenzo   }
1146436c6c9cSStella Laurenzo 
1147436c6c9cSStella Laurenzo   // Unpack results.
1148436c6c9cSStella Laurenzo   std::vector<PyType *> resultTypes;
1149436c6c9cSStella Laurenzo   resultTypes.reserve(resultTypeList.size());
1150436c6c9cSStella Laurenzo   if (resultSegmentSpecObj.is_none()) {
1151436c6c9cSStella Laurenzo     // Non-variadic result unpacking.
1152436c6c9cSStella Laurenzo     for (auto it : llvm::enumerate(resultTypeList)) {
1153436c6c9cSStella Laurenzo       try {
1154436c6c9cSStella Laurenzo         resultTypes.push_back(py::cast<PyType *>(it.value()));
1155436c6c9cSStella Laurenzo         if (!resultTypes.back())
1156436c6c9cSStella Laurenzo           throw py::cast_error();
1157436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
1158436c6c9cSStella Laurenzo         throw py::value_error((llvm::Twine("Result ") +
1159436c6c9cSStella Laurenzo                                llvm::Twine(it.index()) + " of operation \"" +
1160436c6c9cSStella Laurenzo                                name + "\" must be a Type (" + err.what() + ")")
1161436c6c9cSStella Laurenzo                                   .str());
1162436c6c9cSStella Laurenzo       }
1163436c6c9cSStella Laurenzo     }
1164436c6c9cSStella Laurenzo   } else {
1165436c6c9cSStella Laurenzo     // Sized result unpacking.
1166436c6c9cSStella Laurenzo     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1167436c6c9cSStella Laurenzo     if (resultSegmentSpec.size() != resultTypeList.size()) {
1168436c6c9cSStella Laurenzo       throw py::value_error((llvm::Twine("Operation \"") + name +
1169436c6c9cSStella Laurenzo                              "\" requires " +
1170436c6c9cSStella Laurenzo                              llvm::Twine(resultSegmentSpec.size()) +
1171436c6c9cSStella Laurenzo                              " result segments but was provided " +
1172436c6c9cSStella Laurenzo                              llvm::Twine(resultTypeList.size()))
1173436c6c9cSStella Laurenzo                                 .str());
1174436c6c9cSStella Laurenzo     }
1175436c6c9cSStella Laurenzo     resultSegmentLengths.reserve(resultTypeList.size());
1176436c6c9cSStella Laurenzo     for (auto it :
1177436c6c9cSStella Laurenzo          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1178436c6c9cSStella Laurenzo       int segmentSpec = std::get<1>(it.value());
1179436c6c9cSStella Laurenzo       if (segmentSpec == 1 || segmentSpec == 0) {
1180436c6c9cSStella Laurenzo         // Unpack unary element.
1181436c6c9cSStella Laurenzo         try {
11826981e5ecSAlex Zinenko           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1183436c6c9cSStella Laurenzo           if (resultType) {
1184436c6c9cSStella Laurenzo             resultTypes.push_back(resultType);
1185436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(1);
1186436c6c9cSStella Laurenzo           } else if (segmentSpec == 0) {
1187436c6c9cSStella Laurenzo             // Allowed to be optional.
1188436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(0);
1189436c6c9cSStella Laurenzo           } else {
1190436c6c9cSStella Laurenzo             throw py::cast_error("was None and result is not optional");
1191436c6c9cSStella Laurenzo           }
1192436c6c9cSStella Laurenzo         } catch (py::cast_error &err) {
1193436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Result ") +
1194436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1195436c6c9cSStella Laurenzo                                  name + "\" must be a Type (" + err.what() +
1196436c6c9cSStella Laurenzo                                  ")")
1197436c6c9cSStella Laurenzo                                     .str());
1198436c6c9cSStella Laurenzo         }
1199436c6c9cSStella Laurenzo       } else if (segmentSpec == -1) {
1200436c6c9cSStella Laurenzo         // Unpack sequence by appending.
1201436c6c9cSStella Laurenzo         try {
1202436c6c9cSStella Laurenzo           if (std::get<0>(it.value()).is_none()) {
1203436c6c9cSStella Laurenzo             // Treat it as an empty list.
1204436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(0);
1205436c6c9cSStella Laurenzo           } else {
1206436c6c9cSStella Laurenzo             // Unpack the list.
1207436c6c9cSStella Laurenzo             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1208436c6c9cSStella Laurenzo             for (py::object segmentItem : segment) {
1209436c6c9cSStella Laurenzo               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1210436c6c9cSStella Laurenzo               if (!resultTypes.back()) {
1211436c6c9cSStella Laurenzo                 throw py::cast_error("contained a None item");
1212436c6c9cSStella Laurenzo               }
1213436c6c9cSStella Laurenzo             }
1214436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(segment.size());
1215436c6c9cSStella Laurenzo           }
1216436c6c9cSStella Laurenzo         } catch (std::exception &err) {
1217436c6c9cSStella Laurenzo           // NOTE: Sloppy to be using a catch-all here, but there are at least
1218436c6c9cSStella Laurenzo           // three different unrelated exceptions that can be thrown in the
1219436c6c9cSStella Laurenzo           // above "casts". Just keep the scope above small and catch them all.
1220436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Result ") +
1221436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1222436c6c9cSStella Laurenzo                                  name + "\" must be a Sequence of Types (" +
1223436c6c9cSStella Laurenzo                                  err.what() + ")")
1224436c6c9cSStella Laurenzo                                     .str());
1225436c6c9cSStella Laurenzo         }
1226436c6c9cSStella Laurenzo       } else {
1227436c6c9cSStella Laurenzo         throw py::value_error("Unexpected segment spec");
1228436c6c9cSStella Laurenzo       }
1229436c6c9cSStella Laurenzo     }
1230436c6c9cSStella Laurenzo   }
1231436c6c9cSStella Laurenzo 
1232436c6c9cSStella Laurenzo   // Unpack operands.
1233436c6c9cSStella Laurenzo   std::vector<PyValue *> operands;
1234436c6c9cSStella Laurenzo   operands.reserve(operands.size());
1235436c6c9cSStella Laurenzo   if (operandSegmentSpecObj.is_none()) {
1236436c6c9cSStella Laurenzo     // Non-sized operand unpacking.
1237436c6c9cSStella Laurenzo     for (auto it : llvm::enumerate(operandList)) {
1238436c6c9cSStella Laurenzo       try {
1239436c6c9cSStella Laurenzo         operands.push_back(py::cast<PyValue *>(it.value()));
1240436c6c9cSStella Laurenzo         if (!operands.back())
1241436c6c9cSStella Laurenzo           throw py::cast_error();
1242436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
1243436c6c9cSStella Laurenzo         throw py::value_error((llvm::Twine("Operand ") +
1244436c6c9cSStella Laurenzo                                llvm::Twine(it.index()) + " of operation \"" +
1245436c6c9cSStella Laurenzo                                name + "\" must be a Value (" + err.what() + ")")
1246436c6c9cSStella Laurenzo                                   .str());
1247436c6c9cSStella Laurenzo       }
1248436c6c9cSStella Laurenzo     }
1249436c6c9cSStella Laurenzo   } else {
1250436c6c9cSStella Laurenzo     // Sized operand unpacking.
1251436c6c9cSStella Laurenzo     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1252436c6c9cSStella Laurenzo     if (operandSegmentSpec.size() != operandList.size()) {
1253436c6c9cSStella Laurenzo       throw py::value_error((llvm::Twine("Operation \"") + name +
1254436c6c9cSStella Laurenzo                              "\" requires " +
1255436c6c9cSStella Laurenzo                              llvm::Twine(operandSegmentSpec.size()) +
1256436c6c9cSStella Laurenzo                              "operand segments but was provided " +
1257436c6c9cSStella Laurenzo                              llvm::Twine(operandList.size()))
1258436c6c9cSStella Laurenzo                                 .str());
1259436c6c9cSStella Laurenzo     }
1260436c6c9cSStella Laurenzo     operandSegmentLengths.reserve(operandList.size());
1261436c6c9cSStella Laurenzo     for (auto it :
1262436c6c9cSStella Laurenzo          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1263436c6c9cSStella Laurenzo       int segmentSpec = std::get<1>(it.value());
1264436c6c9cSStella Laurenzo       if (segmentSpec == 1 || segmentSpec == 0) {
1265436c6c9cSStella Laurenzo         // Unpack unary element.
1266436c6c9cSStella Laurenzo         try {
1267436c6c9cSStella Laurenzo           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1268436c6c9cSStella Laurenzo           if (operandValue) {
1269436c6c9cSStella Laurenzo             operands.push_back(operandValue);
1270436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(1);
1271436c6c9cSStella Laurenzo           } else if (segmentSpec == 0) {
1272436c6c9cSStella Laurenzo             // Allowed to be optional.
1273436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(0);
1274436c6c9cSStella Laurenzo           } else {
1275436c6c9cSStella Laurenzo             throw py::cast_error("was None and operand is not optional");
1276436c6c9cSStella Laurenzo           }
1277436c6c9cSStella Laurenzo         } catch (py::cast_error &err) {
1278436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Operand ") +
1279436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1280436c6c9cSStella Laurenzo                                  name + "\" must be a Value (" + err.what() +
1281436c6c9cSStella Laurenzo                                  ")")
1282436c6c9cSStella Laurenzo                                     .str());
1283436c6c9cSStella Laurenzo         }
1284436c6c9cSStella Laurenzo       } else if (segmentSpec == -1) {
1285436c6c9cSStella Laurenzo         // Unpack sequence by appending.
1286436c6c9cSStella Laurenzo         try {
1287436c6c9cSStella Laurenzo           if (std::get<0>(it.value()).is_none()) {
1288436c6c9cSStella Laurenzo             // Treat it as an empty list.
1289436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(0);
1290436c6c9cSStella Laurenzo           } else {
1291436c6c9cSStella Laurenzo             // Unpack the list.
1292436c6c9cSStella Laurenzo             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1293436c6c9cSStella Laurenzo             for (py::object segmentItem : segment) {
1294436c6c9cSStella Laurenzo               operands.push_back(py::cast<PyValue *>(segmentItem));
1295436c6c9cSStella Laurenzo               if (!operands.back()) {
1296436c6c9cSStella Laurenzo                 throw py::cast_error("contained a None item");
1297436c6c9cSStella Laurenzo               }
1298436c6c9cSStella Laurenzo             }
1299436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(segment.size());
1300436c6c9cSStella Laurenzo           }
1301436c6c9cSStella Laurenzo         } catch (std::exception &err) {
1302436c6c9cSStella Laurenzo           // NOTE: Sloppy to be using a catch-all here, but there are at least
1303436c6c9cSStella Laurenzo           // three different unrelated exceptions that can be thrown in the
1304436c6c9cSStella Laurenzo           // above "casts". Just keep the scope above small and catch them all.
1305436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Operand ") +
1306436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1307436c6c9cSStella Laurenzo                                  name + "\" must be a Sequence of Values (" +
1308436c6c9cSStella Laurenzo                                  err.what() + ")")
1309436c6c9cSStella Laurenzo                                     .str());
1310436c6c9cSStella Laurenzo         }
1311436c6c9cSStella Laurenzo       } else {
1312436c6c9cSStella Laurenzo         throw py::value_error("Unexpected segment spec");
1313436c6c9cSStella Laurenzo       }
1314436c6c9cSStella Laurenzo     }
1315436c6c9cSStella Laurenzo   }
1316436c6c9cSStella Laurenzo 
1317436c6c9cSStella Laurenzo   // Merge operand/result segment lengths into attributes if needed.
1318436c6c9cSStella Laurenzo   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1319436c6c9cSStella Laurenzo     // Dup.
1320436c6c9cSStella Laurenzo     if (attributes) {
1321436c6c9cSStella Laurenzo       attributes = py::dict(*attributes);
1322436c6c9cSStella Laurenzo     } else {
1323436c6c9cSStella Laurenzo       attributes = py::dict();
1324436c6c9cSStella Laurenzo     }
1325436c6c9cSStella Laurenzo     if (attributes->contains("result_segment_sizes") ||
1326436c6c9cSStella Laurenzo         attributes->contains("operand_segment_sizes")) {
1327436c6c9cSStella Laurenzo       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1328436c6c9cSStella Laurenzo                             "'operand_segment_sizes' attribute is unsupported. "
1329436c6c9cSStella Laurenzo                             "Use Operation.create for such low-level access.");
1330436c6c9cSStella Laurenzo     }
1331436c6c9cSStella Laurenzo 
1332436c6c9cSStella Laurenzo     // Add result_segment_sizes attribute.
1333436c6c9cSStella Laurenzo     if (!resultSegmentLengths.empty()) {
1334436c6c9cSStella Laurenzo       int64_t size = resultSegmentLengths.size();
13358d05a288SStella Laurenzo       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
13368d05a288SStella Laurenzo           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1337436c6c9cSStella Laurenzo           resultSegmentLengths.size(), resultSegmentLengths.data());
1338436c6c9cSStella Laurenzo       (*attributes)["result_segment_sizes"] =
1339436c6c9cSStella Laurenzo           PyAttribute(context, segmentLengthAttr);
1340436c6c9cSStella Laurenzo     }
1341436c6c9cSStella Laurenzo 
1342436c6c9cSStella Laurenzo     // Add operand_segment_sizes attribute.
1343436c6c9cSStella Laurenzo     if (!operandSegmentLengths.empty()) {
1344436c6c9cSStella Laurenzo       int64_t size = operandSegmentLengths.size();
13458d05a288SStella Laurenzo       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
13468d05a288SStella Laurenzo           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1347436c6c9cSStella Laurenzo           operandSegmentLengths.size(), operandSegmentLengths.data());
1348436c6c9cSStella Laurenzo       (*attributes)["operand_segment_sizes"] =
1349436c6c9cSStella Laurenzo           PyAttribute(context, segmentLengthAttr);
1350436c6c9cSStella Laurenzo     }
1351436c6c9cSStella Laurenzo   }
1352436c6c9cSStella Laurenzo 
1353436c6c9cSStella Laurenzo   // Delegate to create.
1354436c6c9cSStella Laurenzo   return PyOperation::create(std::move(name),
1355436c6c9cSStella Laurenzo                              /*results=*/std::move(resultTypes),
1356436c6c9cSStella Laurenzo                              /*operands=*/std::move(operands),
1357436c6c9cSStella Laurenzo                              /*attributes=*/std::move(attributes),
1358436c6c9cSStella Laurenzo                              /*successors=*/std::move(successors),
1359436c6c9cSStella Laurenzo                              /*regions=*/*regions, location, maybeIp);
1360436c6c9cSStella Laurenzo }
1361436c6c9cSStella Laurenzo 
1362436c6c9cSStella Laurenzo PyOpView::PyOpView(py::object operationObject)
1363436c6c9cSStella Laurenzo     // Casting through the PyOperationBase base-class and then back to the
1364436c6c9cSStella Laurenzo     // Operation lets us accept any PyOperationBase subclass.
1365436c6c9cSStella Laurenzo     : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1366436c6c9cSStella Laurenzo       operationObject(operation.getRef().getObject()) {}
1367436c6c9cSStella Laurenzo 
1368436c6c9cSStella Laurenzo py::object PyOpView::createRawSubclass(py::object userClass) {
1369436c6c9cSStella Laurenzo   // This is... a little gross. The typical pattern is to have a pure python
1370436c6c9cSStella Laurenzo   // class that extends OpView like:
1371436c6c9cSStella Laurenzo   //   class AddFOp(_cext.ir.OpView):
1372436c6c9cSStella Laurenzo   //     def __init__(self, loc, lhs, rhs):
1373436c6c9cSStella Laurenzo   //       operation = loc.context.create_operation(
1374436c6c9cSStella Laurenzo   //           "addf", lhs, rhs, results=[lhs.type])
1375436c6c9cSStella Laurenzo   //       super().__init__(operation)
1376436c6c9cSStella Laurenzo   //
1377436c6c9cSStella Laurenzo   // I.e. The goal of the user facing type is to provide a nice constructor
1378436c6c9cSStella Laurenzo   // that has complete freedom for the op under construction. This is at odds
1379436c6c9cSStella Laurenzo   // with our other desire to sometimes create this object by just passing an
1380436c6c9cSStella Laurenzo   // operation (to initialize the base class). We could do *arg and **kwargs
1381436c6c9cSStella Laurenzo   // munging to try to make it work, but instead, we synthesize a new class
1382436c6c9cSStella Laurenzo   // on the fly which extends this user class (AddFOp in this example) and
1383436c6c9cSStella Laurenzo   // *give it* the base class's __init__ method, thus bypassing the
1384436c6c9cSStella Laurenzo   // intermediate subclass's __init__ method entirely. While slightly,
1385436c6c9cSStella Laurenzo   // underhanded, this is safe/legal because the type hierarchy has not changed
1386436c6c9cSStella Laurenzo   // (we just added a new leaf) and we aren't mucking around with __new__.
1387436c6c9cSStella Laurenzo   // Typically, this new class will be stored on the original as "_Raw" and will
1388436c6c9cSStella Laurenzo   // be used for casts and other things that need a variant of the class that
1389436c6c9cSStella Laurenzo   // is initialized purely from an operation.
1390436c6c9cSStella Laurenzo   py::object parentMetaclass =
1391436c6c9cSStella Laurenzo       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1392436c6c9cSStella Laurenzo   py::dict attributes;
1393436c6c9cSStella Laurenzo   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1394436c6c9cSStella Laurenzo   // now.
1395436c6c9cSStella Laurenzo   //   auto opViewType = py::type::of<PyOpView>();
1396436c6c9cSStella Laurenzo   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1397436c6c9cSStella Laurenzo   attributes["__init__"] = opViewType.attr("__init__");
1398436c6c9cSStella Laurenzo   py::str origName = userClass.attr("__name__");
1399436c6c9cSStella Laurenzo   py::str newName = py::str("_") + origName;
1400436c6c9cSStella Laurenzo   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1401436c6c9cSStella Laurenzo }
1402436c6c9cSStella Laurenzo 
1403436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1404436c6c9cSStella Laurenzo // PyInsertionPoint.
1405436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1406436c6c9cSStella Laurenzo 
1407436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1408436c6c9cSStella Laurenzo 
1409436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1410436c6c9cSStella Laurenzo     : refOperation(beforeOperationBase.getOperation().getRef()),
1411436c6c9cSStella Laurenzo       block((*refOperation)->getBlock()) {}
1412436c6c9cSStella Laurenzo 
1413436c6c9cSStella Laurenzo void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1414436c6c9cSStella Laurenzo   PyOperation &operation = operationBase.getOperation();
1415436c6c9cSStella Laurenzo   if (operation.isAttached())
1416436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError,
1417436c6c9cSStella Laurenzo                      "Attempt to insert operation that is already attached");
1418436c6c9cSStella Laurenzo   block.getParentOperation()->checkValid();
1419436c6c9cSStella Laurenzo   MlirOperation beforeOp = {nullptr};
1420436c6c9cSStella Laurenzo   if (refOperation) {
1421436c6c9cSStella Laurenzo     // Insert before operation.
1422436c6c9cSStella Laurenzo     (*refOperation)->checkValid();
1423436c6c9cSStella Laurenzo     beforeOp = (*refOperation)->get();
1424436c6c9cSStella Laurenzo   } else {
1425436c6c9cSStella Laurenzo     // Insert at end (before null) is only valid if the block does not
1426436c6c9cSStella Laurenzo     // already end in a known terminator (violating this will cause assertion
1427436c6c9cSStella Laurenzo     // failures later).
1428436c6c9cSStella Laurenzo     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1429436c6c9cSStella Laurenzo       throw py::index_error("Cannot insert operation at the end of a block "
1430436c6c9cSStella Laurenzo                             "that already has a terminator. Did you mean to "
1431436c6c9cSStella Laurenzo                             "use 'InsertionPoint.at_block_terminator(block)' "
1432436c6c9cSStella Laurenzo                             "versus 'InsertionPoint(block)'?");
1433436c6c9cSStella Laurenzo     }
1434436c6c9cSStella Laurenzo   }
1435436c6c9cSStella Laurenzo   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1436436c6c9cSStella Laurenzo   operation.setAttached();
1437436c6c9cSStella Laurenzo }
1438436c6c9cSStella Laurenzo 
1439436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1440436c6c9cSStella Laurenzo   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1441436c6c9cSStella Laurenzo   if (mlirOperationIsNull(firstOp)) {
1442436c6c9cSStella Laurenzo     // Just insert at end.
1443436c6c9cSStella Laurenzo     return PyInsertionPoint(block);
1444436c6c9cSStella Laurenzo   }
1445436c6c9cSStella Laurenzo 
1446436c6c9cSStella Laurenzo   // Insert before first op.
1447436c6c9cSStella Laurenzo   PyOperationRef firstOpRef = PyOperation::forOperation(
1448436c6c9cSStella Laurenzo       block.getParentOperation()->getContext(), firstOp);
1449436c6c9cSStella Laurenzo   return PyInsertionPoint{block, std::move(firstOpRef)};
1450436c6c9cSStella Laurenzo }
1451436c6c9cSStella Laurenzo 
1452436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1453436c6c9cSStella Laurenzo   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1454436c6c9cSStella Laurenzo   if (mlirOperationIsNull(terminator))
1455436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1456436c6c9cSStella Laurenzo   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1457436c6c9cSStella Laurenzo       block.getParentOperation()->getContext(), terminator);
1458436c6c9cSStella Laurenzo   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1459436c6c9cSStella Laurenzo }
1460436c6c9cSStella Laurenzo 
1461436c6c9cSStella Laurenzo py::object PyInsertionPoint::contextEnter() {
1462436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushInsertionPoint(*this);
1463436c6c9cSStella Laurenzo }
1464436c6c9cSStella Laurenzo 
1465436c6c9cSStella Laurenzo void PyInsertionPoint::contextExit(pybind11::object excType,
1466436c6c9cSStella Laurenzo                                    pybind11::object excVal,
1467436c6c9cSStella Laurenzo                                    pybind11::object excTb) {
1468436c6c9cSStella Laurenzo   PyThreadContextEntry::popInsertionPoint(*this);
1469436c6c9cSStella Laurenzo }
1470436c6c9cSStella Laurenzo 
1471436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1472436c6c9cSStella Laurenzo // PyAttribute.
1473436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1474436c6c9cSStella Laurenzo 
1475436c6c9cSStella Laurenzo bool PyAttribute::operator==(const PyAttribute &other) {
1476436c6c9cSStella Laurenzo   return mlirAttributeEqual(attr, other.attr);
1477436c6c9cSStella Laurenzo }
1478436c6c9cSStella Laurenzo 
1479436c6c9cSStella Laurenzo py::object PyAttribute::getCapsule() {
1480436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1481436c6c9cSStella Laurenzo }
1482436c6c9cSStella Laurenzo 
1483436c6c9cSStella Laurenzo PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1484436c6c9cSStella Laurenzo   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1485436c6c9cSStella Laurenzo   if (mlirAttributeIsNull(rawAttr))
1486436c6c9cSStella Laurenzo     throw py::error_already_set();
1487436c6c9cSStella Laurenzo   return PyAttribute(
1488436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1489436c6c9cSStella Laurenzo }
1490436c6c9cSStella Laurenzo 
1491436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1492436c6c9cSStella Laurenzo // PyNamedAttribute.
1493436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1494436c6c9cSStella Laurenzo 
1495436c6c9cSStella Laurenzo PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1496436c6c9cSStella Laurenzo     : ownedName(new std::string(std::move(ownedName))) {
1497436c6c9cSStella Laurenzo   namedAttr = mlirNamedAttributeGet(
1498436c6c9cSStella Laurenzo       mlirIdentifierGet(mlirAttributeGetContext(attr),
1499436c6c9cSStella Laurenzo                         toMlirStringRef(*this->ownedName)),
1500436c6c9cSStella Laurenzo       attr);
1501436c6c9cSStella Laurenzo }
1502436c6c9cSStella Laurenzo 
1503436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1504436c6c9cSStella Laurenzo // PyType.
1505436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1506436c6c9cSStella Laurenzo 
1507436c6c9cSStella Laurenzo bool PyType::operator==(const PyType &other) {
1508436c6c9cSStella Laurenzo   return mlirTypeEqual(type, other.type);
1509436c6c9cSStella Laurenzo }
1510436c6c9cSStella Laurenzo 
1511436c6c9cSStella Laurenzo py::object PyType::getCapsule() {
1512436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1513436c6c9cSStella Laurenzo }
1514436c6c9cSStella Laurenzo 
1515436c6c9cSStella Laurenzo PyType PyType::createFromCapsule(py::object capsule) {
1516436c6c9cSStella Laurenzo   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1517436c6c9cSStella Laurenzo   if (mlirTypeIsNull(rawType))
1518436c6c9cSStella Laurenzo     throw py::error_already_set();
1519436c6c9cSStella Laurenzo   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1520436c6c9cSStella Laurenzo                 rawType);
1521436c6c9cSStella Laurenzo }
1522436c6c9cSStella Laurenzo 
1523436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
// PyValue and subclasses.
1525436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1526436c6c9cSStella Laurenzo 
15273f3d1c90SMike Urbach pybind11::object PyValue::getCapsule() {
15283f3d1c90SMike Urbach   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
15293f3d1c90SMike Urbach }
15303f3d1c90SMike Urbach 
15313f3d1c90SMike Urbach PyValue PyValue::createFromCapsule(pybind11::object capsule) {
15323f3d1c90SMike Urbach   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
15333f3d1c90SMike Urbach   if (mlirValueIsNull(value))
15343f3d1c90SMike Urbach     throw py::error_already_set();
15353f3d1c90SMike Urbach   MlirOperation owner;
15363f3d1c90SMike Urbach   if (mlirValueIsAOpResult(value))
15373f3d1c90SMike Urbach     owner = mlirOpResultGetOwner(value);
15383f3d1c90SMike Urbach   if (mlirValueIsABlockArgument(value))
15393f3d1c90SMike Urbach     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
15403f3d1c90SMike Urbach   if (mlirOperationIsNull(owner))
15413f3d1c90SMike Urbach     throw py::error_already_set();
15423f3d1c90SMike Urbach   MlirContext ctx = mlirOperationGetContext(owner);
15433f3d1c90SMike Urbach   PyOperationRef ownerRef =
15443f3d1c90SMike Urbach       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
15453f3d1c90SMike Urbach   return PyValue(ownerRef, value);
15463f3d1c90SMike Urbach }
15473f3d1c90SMike Urbach 
154830d61893SAlex Zinenko //------------------------------------------------------------------------------
154930d61893SAlex Zinenko // PySymbolTable.
155030d61893SAlex Zinenko //------------------------------------------------------------------------------
155130d61893SAlex Zinenko 
155230d61893SAlex Zinenko PySymbolTable::PySymbolTable(PyOperationBase &operation)
155330d61893SAlex Zinenko     : operation(operation.getOperation().getRef()) {
155430d61893SAlex Zinenko   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
155530d61893SAlex Zinenko   if (mlirSymbolTableIsNull(symbolTable)) {
155630d61893SAlex Zinenko     throw py::cast_error("Operation is not a Symbol Table.");
155730d61893SAlex Zinenko   }
155830d61893SAlex Zinenko }
155930d61893SAlex Zinenko 
156030d61893SAlex Zinenko py::object PySymbolTable::dunderGetItem(const std::string &name) {
156130d61893SAlex Zinenko   operation->checkValid();
156230d61893SAlex Zinenko   MlirOperation symbol = mlirSymbolTableLookup(
156330d61893SAlex Zinenko       symbolTable, mlirStringRefCreate(name.data(), name.length()));
156430d61893SAlex Zinenko   if (mlirOperationIsNull(symbol))
156530d61893SAlex Zinenko     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
156630d61893SAlex Zinenko 
156730d61893SAlex Zinenko   return PyOperation::forOperation(operation->getContext(), symbol,
156830d61893SAlex Zinenko                                    operation.getObject())
156930d61893SAlex Zinenko       ->createOpView();
157030d61893SAlex Zinenko }
157130d61893SAlex Zinenko 
157230d61893SAlex Zinenko void PySymbolTable::erase(PyOperationBase &symbol) {
157330d61893SAlex Zinenko   operation->checkValid();
157430d61893SAlex Zinenko   symbol.getOperation().checkValid();
157530d61893SAlex Zinenko   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
157630d61893SAlex Zinenko   // The operation is also erased, so we must invalidate it. There may be Python
157730d61893SAlex Zinenko   // references to this operation so we don't want to delete it from the list of
157830d61893SAlex Zinenko   // live operations here.
157930d61893SAlex Zinenko   symbol.getOperation().valid = false;
158030d61893SAlex Zinenko }
158130d61893SAlex Zinenko 
158230d61893SAlex Zinenko void PySymbolTable::dunderDel(const std::string &name) {
158330d61893SAlex Zinenko   py::object operation = dunderGetItem(name);
158430d61893SAlex Zinenko   erase(py::cast<PyOperationBase &>(operation));
158530d61893SAlex Zinenko }
158630d61893SAlex Zinenko 
158730d61893SAlex Zinenko PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
158830d61893SAlex Zinenko   operation->checkValid();
158930d61893SAlex Zinenko   symbol.getOperation().checkValid();
159030d61893SAlex Zinenko   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
159130d61893SAlex Zinenko       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
159230d61893SAlex Zinenko   if (mlirAttributeIsNull(symbolAttr))
159330d61893SAlex Zinenko     throw py::value_error("Expected operation to have a symbol name.");
159430d61893SAlex Zinenko   return PyAttribute(
159530d61893SAlex Zinenko       symbol.getOperation().getContext(),
159630d61893SAlex Zinenko       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
159730d61893SAlex Zinenko }
159830d61893SAlex Zinenko 
1599436c6c9cSStella Laurenzo namespace {
1600436c6c9cSStella Laurenzo /// CRTP base class for Python MLIR values that subclass Value and should be
1601436c6c9cSStella Laurenzo /// castable from it. The value hierarchy is one level deep and is not supposed
1602436c6c9cSStella Laurenzo /// to accommodate other levels unless core MLIR changes.
1603436c6c9cSStella Laurenzo template <typename DerivedTy>
1604436c6c9cSStella Laurenzo class PyConcreteValue : public PyValue {
1605436c6c9cSStella Laurenzo public:
1606436c6c9cSStella Laurenzo   // Derived classes must define statics for:
1607436c6c9cSStella Laurenzo   //   IsAFunctionTy isaFunction
1608436c6c9cSStella Laurenzo   //   const char *pyClassName
1609436c6c9cSStella Laurenzo   // and redefine bindDerived.
1610436c6c9cSStella Laurenzo   using ClassTy = py::class_<DerivedTy, PyValue>;
1611436c6c9cSStella Laurenzo   using IsAFunctionTy = bool (*)(MlirValue);
1612436c6c9cSStella Laurenzo 
1613436c6c9cSStella Laurenzo   PyConcreteValue() = default;
1614436c6c9cSStella Laurenzo   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1615436c6c9cSStella Laurenzo       : PyValue(operationRef, value) {}
1616436c6c9cSStella Laurenzo   PyConcreteValue(PyValue &orig)
1617436c6c9cSStella Laurenzo       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1618436c6c9cSStella Laurenzo 
1619436c6c9cSStella Laurenzo   /// Attempts to cast the original value to the derived type and throws on
1620436c6c9cSStella Laurenzo   /// type mismatches.
1621436c6c9cSStella Laurenzo   static MlirValue castFrom(PyValue &orig) {
1622436c6c9cSStella Laurenzo     if (!DerivedTy::isaFunction(orig.get())) {
1623436c6c9cSStella Laurenzo       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1624436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1625436c6c9cSStella Laurenzo                                              DerivedTy::pyClassName +
1626436c6c9cSStella Laurenzo                                              " (from " + origRepr + ")");
1627436c6c9cSStella Laurenzo     }
1628436c6c9cSStella Laurenzo     return orig.get();
1629436c6c9cSStella Laurenzo   }
1630436c6c9cSStella Laurenzo 
1631436c6c9cSStella Laurenzo   /// Binds the Python module objects to functions of this class.
1632436c6c9cSStella Laurenzo   static void bind(py::module &m) {
1633f05ff4f7SStella Laurenzo     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1634*a6e7d024SStella Laurenzo     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1635*a6e7d024SStella Laurenzo     cls.def_static(
1636*a6e7d024SStella Laurenzo         "isinstance",
1637*a6e7d024SStella Laurenzo         [](PyValue &otherValue) -> bool {
163878f2dae0SAlex Zinenko           return DerivedTy::isaFunction(otherValue);
1639*a6e7d024SStella Laurenzo         },
1640*a6e7d024SStella Laurenzo         py::arg("other_value"));
1641436c6c9cSStella Laurenzo     DerivedTy::bindDerived(cls);
1642436c6c9cSStella Laurenzo   }
1643436c6c9cSStella Laurenzo 
1644436c6c9cSStella Laurenzo   /// Implemented by derived classes to add methods to the Python subclass.
1645436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &m) {}
1646436c6c9cSStella Laurenzo };
1647436c6c9cSStella Laurenzo 
1648436c6c9cSStella Laurenzo /// Python wrapper for MlirBlockArgument.
1649436c6c9cSStella Laurenzo class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1650436c6c9cSStella Laurenzo public:
1651436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1652436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BlockArgument";
1653436c6c9cSStella Laurenzo   using PyConcreteValue::PyConcreteValue;
1654436c6c9cSStella Laurenzo 
1655436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1656436c6c9cSStella Laurenzo     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1657436c6c9cSStella Laurenzo       return PyBlock(self.getParentOperation(),
1658436c6c9cSStella Laurenzo                      mlirBlockArgumentGetOwner(self.get()));
1659436c6c9cSStella Laurenzo     });
1660436c6c9cSStella Laurenzo     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1661436c6c9cSStella Laurenzo       return mlirBlockArgumentGetArgNumber(self.get());
1662436c6c9cSStella Laurenzo     });
1663*a6e7d024SStella Laurenzo     c.def(
1664*a6e7d024SStella Laurenzo         "set_type",
1665*a6e7d024SStella Laurenzo         [](PyBlockArgument &self, PyType type) {
1666436c6c9cSStella Laurenzo           return mlirBlockArgumentSetType(self.get(), type);
1667*a6e7d024SStella Laurenzo         },
1668*a6e7d024SStella Laurenzo         py::arg("type"));
1669436c6c9cSStella Laurenzo   }
1670436c6c9cSStella Laurenzo };
1671436c6c9cSStella Laurenzo 
1672436c6c9cSStella Laurenzo /// Python wrapper for MlirOpResult.
1673436c6c9cSStella Laurenzo class PyOpResult : public PyConcreteValue<PyOpResult> {
1674436c6c9cSStella Laurenzo public:
1675436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1676436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpResult";
1677436c6c9cSStella Laurenzo   using PyConcreteValue::PyConcreteValue;
1678436c6c9cSStella Laurenzo 
1679436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1680436c6c9cSStella Laurenzo     c.def_property_readonly("owner", [](PyOpResult &self) {
1681436c6c9cSStella Laurenzo       assert(
1682436c6c9cSStella Laurenzo           mlirOperationEqual(self.getParentOperation()->get(),
1683436c6c9cSStella Laurenzo                              mlirOpResultGetOwner(self.get())) &&
1684436c6c9cSStella Laurenzo           "expected the owner of the value in Python to match that in the IR");
16856ff74f96SMike Urbach       return self.getParentOperation().getObject();
1686436c6c9cSStella Laurenzo     });
1687436c6c9cSStella Laurenzo     c.def_property_readonly("result_number", [](PyOpResult &self) {
1688436c6c9cSStella Laurenzo       return mlirOpResultGetResultNumber(self.get());
1689436c6c9cSStella Laurenzo     });
1690436c6c9cSStella Laurenzo   }
1691436c6c9cSStella Laurenzo };
1692436c6c9cSStella Laurenzo 
1693ed9e52f3SAlex Zinenko /// Returns the list of types of the values held by container.
1694ed9e52f3SAlex Zinenko template <typename Container>
1695ed9e52f3SAlex Zinenko static std::vector<PyType> getValueTypes(Container &container,
1696ed9e52f3SAlex Zinenko                                          PyMlirContextRef &context) {
1697ed9e52f3SAlex Zinenko   std::vector<PyType> result;
1698ed9e52f3SAlex Zinenko   result.reserve(container.getNumElements());
1699ed9e52f3SAlex Zinenko   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1700ed9e52f3SAlex Zinenko     result.push_back(
1701ed9e52f3SAlex Zinenko         PyType(context, mlirValueGetType(container.getElement(i).get())));
1702ed9e52f3SAlex Zinenko   }
1703ed9e52f3SAlex Zinenko   return result;
1704ed9e52f3SAlex Zinenko }
1705ed9e52f3SAlex Zinenko 
1706436c6c9cSStella Laurenzo /// A list of block arguments. Internally, these are stored as consecutive
1707436c6c9cSStella Laurenzo /// elements, random access is cheap. The argument list is associated with the
1708436c6c9cSStella Laurenzo /// operation that contains the block (detached blocks are not allowed in
1709436c6c9cSStella Laurenzo /// Python bindings) and extends its lifetime.
1710afeda4b9SAlex Zinenko class PyBlockArgumentList
1711afeda4b9SAlex Zinenko     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1712436c6c9cSStella Laurenzo public:
1713afeda4b9SAlex Zinenko   static constexpr const char *pyClassName = "BlockArgumentList";
1714436c6c9cSStella Laurenzo 
1715afeda4b9SAlex Zinenko   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1716afeda4b9SAlex Zinenko                       intptr_t startIndex = 0, intptr_t length = -1,
1717afeda4b9SAlex Zinenko                       intptr_t step = 1)
1718afeda4b9SAlex Zinenko       : Sliceable(startIndex,
1719afeda4b9SAlex Zinenko                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1720afeda4b9SAlex Zinenko                   step),
1721afeda4b9SAlex Zinenko         operation(std::move(operation)), block(block) {}
1722afeda4b9SAlex Zinenko 
1723afeda4b9SAlex Zinenko   /// Returns the number of arguments in the list.
1724afeda4b9SAlex Zinenko   intptr_t getNumElements() {
1725436c6c9cSStella Laurenzo     operation->checkValid();
1726436c6c9cSStella Laurenzo     return mlirBlockGetNumArguments(block);
1727436c6c9cSStella Laurenzo   }
1728436c6c9cSStella Laurenzo 
1729afeda4b9SAlex Zinenko   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1730afeda4b9SAlex Zinenko   PyBlockArgument getElement(intptr_t pos) {
1731afeda4b9SAlex Zinenko     MlirValue argument = mlirBlockGetArgument(block, pos);
1732afeda4b9SAlex Zinenko     return PyBlockArgument(operation, argument);
1733436c6c9cSStella Laurenzo   }
1734436c6c9cSStella Laurenzo 
1735afeda4b9SAlex Zinenko   /// Returns a sublist of this list.
1736afeda4b9SAlex Zinenko   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1737afeda4b9SAlex Zinenko                             intptr_t step) {
1738afeda4b9SAlex Zinenko     return PyBlockArgumentList(operation, block, startIndex, length, step);
1739436c6c9cSStella Laurenzo   }
1740436c6c9cSStella Laurenzo 
1741ed9e52f3SAlex Zinenko   static void bindDerived(ClassTy &c) {
1742ed9e52f3SAlex Zinenko     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1743ed9e52f3SAlex Zinenko       return getValueTypes(self, self.operation->getContext());
1744ed9e52f3SAlex Zinenko     });
1745ed9e52f3SAlex Zinenko   }
1746ed9e52f3SAlex Zinenko 
1747436c6c9cSStella Laurenzo private:
1748436c6c9cSStella Laurenzo   PyOperationRef operation;
1749436c6c9cSStella Laurenzo   MlirBlock block;
1750436c6c9cSStella Laurenzo };
1751436c6c9cSStella Laurenzo 
1752436c6c9cSStella Laurenzo /// A list of operation operands. Internally, these are stored as consecutive
1753436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the
1754436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this
1755436c6c9cSStella Laurenzo /// operation.
1756436c6c9cSStella Laurenzo class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1757436c6c9cSStella Laurenzo public:
1758436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpOperandList";
1759436c6c9cSStella Laurenzo 
1760436c6c9cSStella Laurenzo   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1761436c6c9cSStella Laurenzo                   intptr_t length = -1, intptr_t step = 1)
1762436c6c9cSStella Laurenzo       : Sliceable(startIndex,
1763436c6c9cSStella Laurenzo                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1764436c6c9cSStella Laurenzo                                : length,
1765436c6c9cSStella Laurenzo                   step),
1766436c6c9cSStella Laurenzo         operation(operation) {}
1767436c6c9cSStella Laurenzo 
1768436c6c9cSStella Laurenzo   intptr_t getNumElements() {
1769436c6c9cSStella Laurenzo     operation->checkValid();
1770436c6c9cSStella Laurenzo     return mlirOperationGetNumOperands(operation->get());
1771436c6c9cSStella Laurenzo   }
1772436c6c9cSStella Laurenzo 
1773436c6c9cSStella Laurenzo   PyValue getElement(intptr_t pos) {
17745664c5e2SJohn Demme     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
17755664c5e2SJohn Demme     MlirOperation owner;
17765664c5e2SJohn Demme     if (mlirValueIsAOpResult(operand))
17775664c5e2SJohn Demme       owner = mlirOpResultGetOwner(operand);
17785664c5e2SJohn Demme     else if (mlirValueIsABlockArgument(operand))
17795664c5e2SJohn Demme       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
17805664c5e2SJohn Demme     else
17815664c5e2SJohn Demme       assert(false && "Value must be an block arg or op result.");
17825664c5e2SJohn Demme     PyOperationRef pyOwner =
17835664c5e2SJohn Demme         PyOperation::forOperation(operation->getContext(), owner);
17845664c5e2SJohn Demme     return PyValue(pyOwner, operand);
1785436c6c9cSStella Laurenzo   }
1786436c6c9cSStella Laurenzo 
1787436c6c9cSStella Laurenzo   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1788436c6c9cSStella Laurenzo     return PyOpOperandList(operation, startIndex, length, step);
1789436c6c9cSStella Laurenzo   }
1790436c6c9cSStella Laurenzo 
179163d16d06SMike Urbach   void dunderSetItem(intptr_t index, PyValue value) {
179263d16d06SMike Urbach     index = wrapIndex(index);
179363d16d06SMike Urbach     mlirOperationSetOperand(operation->get(), index, value.get());
179463d16d06SMike Urbach   }
179563d16d06SMike Urbach 
179663d16d06SMike Urbach   static void bindDerived(ClassTy &c) {
179763d16d06SMike Urbach     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
179863d16d06SMike Urbach   }
179963d16d06SMike Urbach 
1800436c6c9cSStella Laurenzo private:
1801436c6c9cSStella Laurenzo   PyOperationRef operation;
1802436c6c9cSStella Laurenzo };
1803436c6c9cSStella Laurenzo 
1804436c6c9cSStella Laurenzo /// A list of operation results. Internally, these are stored as consecutive
1805436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the
1806436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this
1807436c6c9cSStella Laurenzo /// operation.
1808436c6c9cSStella Laurenzo class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1809436c6c9cSStella Laurenzo public:
1810436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpResultList";
1811436c6c9cSStella Laurenzo 
1812436c6c9cSStella Laurenzo   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1813436c6c9cSStella Laurenzo                  intptr_t length = -1, intptr_t step = 1)
1814436c6c9cSStella Laurenzo       : Sliceable(startIndex,
1815436c6c9cSStella Laurenzo                   length == -1 ? mlirOperationGetNumResults(operation->get())
1816436c6c9cSStella Laurenzo                                : length,
1817436c6c9cSStella Laurenzo                   step),
1818436c6c9cSStella Laurenzo         operation(operation) {}
1819436c6c9cSStella Laurenzo 
1820436c6c9cSStella Laurenzo   intptr_t getNumElements() {
1821436c6c9cSStella Laurenzo     operation->checkValid();
1822436c6c9cSStella Laurenzo     return mlirOperationGetNumResults(operation->get());
1823436c6c9cSStella Laurenzo   }
1824436c6c9cSStella Laurenzo 
1825436c6c9cSStella Laurenzo   PyOpResult getElement(intptr_t index) {
1826436c6c9cSStella Laurenzo     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1827436c6c9cSStella Laurenzo     return PyOpResult(value);
1828436c6c9cSStella Laurenzo   }
1829436c6c9cSStella Laurenzo 
1830436c6c9cSStella Laurenzo   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1831436c6c9cSStella Laurenzo     return PyOpResultList(operation, startIndex, length, step);
1832436c6c9cSStella Laurenzo   }
1833436c6c9cSStella Laurenzo 
1834ed9e52f3SAlex Zinenko   static void bindDerived(ClassTy &c) {
1835ed9e52f3SAlex Zinenko     c.def_property_readonly("types", [](PyOpResultList &self) {
1836ed9e52f3SAlex Zinenko       return getValueTypes(self, self.operation->getContext());
1837ed9e52f3SAlex Zinenko     });
1838ed9e52f3SAlex Zinenko   }
1839ed9e52f3SAlex Zinenko 
1840436c6c9cSStella Laurenzo private:
1841436c6c9cSStella Laurenzo   PyOperationRef operation;
1842436c6c9cSStella Laurenzo };
1843436c6c9cSStella Laurenzo 
1844436c6c9cSStella Laurenzo /// A list of operation attributes. Can be indexed by name, producing
1845436c6c9cSStella Laurenzo /// attributes, or by index, producing named attributes.
1846436c6c9cSStella Laurenzo class PyOpAttributeMap {
1847436c6c9cSStella Laurenzo public:
1848436c6c9cSStella Laurenzo   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1849436c6c9cSStella Laurenzo 
1850436c6c9cSStella Laurenzo   PyAttribute dunderGetItemNamed(const std::string &name) {
1851436c6c9cSStella Laurenzo     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1852436c6c9cSStella Laurenzo                                                          toMlirStringRef(name));
1853436c6c9cSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
1854436c6c9cSStella Laurenzo       throw SetPyError(PyExc_KeyError,
1855436c6c9cSStella Laurenzo                        "attempt to access a non-existent attribute");
1856436c6c9cSStella Laurenzo     }
1857436c6c9cSStella Laurenzo     return PyAttribute(operation->getContext(), attr);
1858436c6c9cSStella Laurenzo   }
1859436c6c9cSStella Laurenzo 
1860436c6c9cSStella Laurenzo   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1861436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
1862436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
1863436c6c9cSStella Laurenzo                        "attempt to access out of bounds attribute");
1864436c6c9cSStella Laurenzo     }
1865436c6c9cSStella Laurenzo     MlirNamedAttribute namedAttr =
1866436c6c9cSStella Laurenzo         mlirOperationGetAttribute(operation->get(), index);
1867436c6c9cSStella Laurenzo     return PyNamedAttribute(
1868436c6c9cSStella Laurenzo         namedAttr.attribute,
1869120591e1SRiver Riddle         std::string(mlirIdentifierStr(namedAttr.name).data,
1870120591e1SRiver Riddle                     mlirIdentifierStr(namedAttr.name).length));
1871436c6c9cSStella Laurenzo   }
1872436c6c9cSStella Laurenzo 
1873436c6c9cSStella Laurenzo   void dunderSetItem(const std::string &name, PyAttribute attr) {
1874436c6c9cSStella Laurenzo     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1875436c6c9cSStella Laurenzo                                     attr);
1876436c6c9cSStella Laurenzo   }
1877436c6c9cSStella Laurenzo 
1878436c6c9cSStella Laurenzo   void dunderDelItem(const std::string &name) {
1879436c6c9cSStella Laurenzo     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1880436c6c9cSStella Laurenzo                                                      toMlirStringRef(name));
1881436c6c9cSStella Laurenzo     if (!removed)
1882436c6c9cSStella Laurenzo       throw SetPyError(PyExc_KeyError,
1883436c6c9cSStella Laurenzo                        "attempt to delete a non-existent attribute");
1884436c6c9cSStella Laurenzo   }
1885436c6c9cSStella Laurenzo 
1886436c6c9cSStella Laurenzo   intptr_t dunderLen() {
1887436c6c9cSStella Laurenzo     return mlirOperationGetNumAttributes(operation->get());
1888436c6c9cSStella Laurenzo   }
1889436c6c9cSStella Laurenzo 
1890436c6c9cSStella Laurenzo   bool dunderContains(const std::string &name) {
1891436c6c9cSStella Laurenzo     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1892436c6c9cSStella Laurenzo         operation->get(), toMlirStringRef(name)));
1893436c6c9cSStella Laurenzo   }
1894436c6c9cSStella Laurenzo 
1895436c6c9cSStella Laurenzo   static void bind(py::module &m) {
1896f05ff4f7SStella Laurenzo     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
1897436c6c9cSStella Laurenzo         .def("__contains__", &PyOpAttributeMap::dunderContains)
1898436c6c9cSStella Laurenzo         .def("__len__", &PyOpAttributeMap::dunderLen)
1899436c6c9cSStella Laurenzo         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1900436c6c9cSStella Laurenzo         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1901436c6c9cSStella Laurenzo         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1902436c6c9cSStella Laurenzo         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1903436c6c9cSStella Laurenzo   }
1904436c6c9cSStella Laurenzo 
1905436c6c9cSStella Laurenzo private:
1906436c6c9cSStella Laurenzo   PyOperationRef operation;
1907436c6c9cSStella Laurenzo };
1908436c6c9cSStella Laurenzo 
1909436c6c9cSStella Laurenzo } // end namespace
1910436c6c9cSStella Laurenzo 
1911436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1912436c6c9cSStella Laurenzo // Populates the core exports of the 'ir' submodule.
1913436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1914436c6c9cSStella Laurenzo 
1915436c6c9cSStella Laurenzo void mlir::python::populateIRCore(py::module &m) {
1916436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
19174acd8457SAlex Zinenko   // Mapping of MlirContext.
1918436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1919f05ff4f7SStella Laurenzo   py::class_<PyMlirContext>(m, "Context", py::module_local())
1920436c6c9cSStella Laurenzo       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1921436c6c9cSStella Laurenzo       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1922436c6c9cSStella Laurenzo       .def("_get_context_again",
1923436c6c9cSStella Laurenzo            [](PyMlirContext &self) {
1924436c6c9cSStella Laurenzo              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1925436c6c9cSStella Laurenzo              return ref.releaseObject();
1926436c6c9cSStella Laurenzo            })
1927436c6c9cSStella Laurenzo       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1928436c6c9cSStella Laurenzo       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1929436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1930436c6c9cSStella Laurenzo                              &PyMlirContext::getCapsule)
1931436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1932436c6c9cSStella Laurenzo       .def("__enter__", &PyMlirContext::contextEnter)
1933436c6c9cSStella Laurenzo       .def("__exit__", &PyMlirContext::contextExit)
1934436c6c9cSStella Laurenzo       .def_property_readonly_static(
1935436c6c9cSStella Laurenzo           "current",
1936436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
1937436c6c9cSStella Laurenzo             auto *context = PyThreadContextEntry::getDefaultContext();
1938436c6c9cSStella Laurenzo             if (!context)
1939436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current Context");
1940436c6c9cSStella Laurenzo             return context;
1941436c6c9cSStella Laurenzo           },
1942436c6c9cSStella Laurenzo           "Gets the Context bound to the current thread or raises ValueError")
1943436c6c9cSStella Laurenzo       .def_property_readonly(
1944436c6c9cSStella Laurenzo           "dialects",
1945436c6c9cSStella Laurenzo           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1946436c6c9cSStella Laurenzo           "Gets a container for accessing dialects by name")
1947436c6c9cSStella Laurenzo       .def_property_readonly(
1948436c6c9cSStella Laurenzo           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1949436c6c9cSStella Laurenzo           "Alias for 'dialect'")
1950436c6c9cSStella Laurenzo       .def(
1951436c6c9cSStella Laurenzo           "get_dialect_descriptor",
1952436c6c9cSStella Laurenzo           [=](PyMlirContext &self, std::string &name) {
1953436c6c9cSStella Laurenzo             MlirDialect dialect = mlirContextGetOrLoadDialect(
1954436c6c9cSStella Laurenzo                 self.get(), {name.data(), name.size()});
1955436c6c9cSStella Laurenzo             if (mlirDialectIsNull(dialect)) {
1956436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
1957436c6c9cSStella Laurenzo                                Twine("Dialect '") + name + "' not found");
1958436c6c9cSStella Laurenzo             }
1959436c6c9cSStella Laurenzo             return PyDialectDescriptor(self.getRef(), dialect);
1960436c6c9cSStella Laurenzo           },
1961*a6e7d024SStella Laurenzo           py::arg("dialect_name"),
1962436c6c9cSStella Laurenzo           "Gets or loads a dialect by name, returning its descriptor object")
1963436c6c9cSStella Laurenzo       .def_property(
1964436c6c9cSStella Laurenzo           "allow_unregistered_dialects",
1965436c6c9cSStella Laurenzo           [](PyMlirContext &self) -> bool {
1966436c6c9cSStella Laurenzo             return mlirContextGetAllowUnregisteredDialects(self.get());
1967436c6c9cSStella Laurenzo           },
1968436c6c9cSStella Laurenzo           [](PyMlirContext &self, bool value) {
1969436c6c9cSStella Laurenzo             mlirContextSetAllowUnregisteredDialects(self.get(), value);
19709a9214faSStella Laurenzo           })
1971*a6e7d024SStella Laurenzo       .def(
1972*a6e7d024SStella Laurenzo           "enable_multithreading",
1973caa159f0SNicolas Vasilache           [](PyMlirContext &self, bool enable) {
1974caa159f0SNicolas Vasilache             mlirContextEnableMultithreading(self.get(), enable);
1975*a6e7d024SStella Laurenzo           },
1976*a6e7d024SStella Laurenzo           py::arg("enable"))
1977*a6e7d024SStella Laurenzo       .def(
1978*a6e7d024SStella Laurenzo           "is_registered_operation",
19799a9214faSStella Laurenzo           [](PyMlirContext &self, std::string &name) {
19809a9214faSStella Laurenzo             return mlirContextIsRegisteredOperation(
19819a9214faSStella Laurenzo                 self.get(), MlirStringRef{name.data(), name.size()});
1982*a6e7d024SStella Laurenzo           },
1983*a6e7d024SStella Laurenzo           py::arg("operation_name"));
1984436c6c9cSStella Laurenzo 
1985436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1986436c6c9cSStella Laurenzo   // Mapping of PyDialectDescriptor
1987436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1988f05ff4f7SStella Laurenzo   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
1989436c6c9cSStella Laurenzo       .def_property_readonly("namespace",
1990436c6c9cSStella Laurenzo                              [](PyDialectDescriptor &self) {
1991436c6c9cSStella Laurenzo                                MlirStringRef ns =
1992436c6c9cSStella Laurenzo                                    mlirDialectGetNamespace(self.get());
1993436c6c9cSStella Laurenzo                                return py::str(ns.data, ns.length);
1994436c6c9cSStella Laurenzo                              })
1995436c6c9cSStella Laurenzo       .def("__repr__", [](PyDialectDescriptor &self) {
1996436c6c9cSStella Laurenzo         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1997436c6c9cSStella Laurenzo         std::string repr("<DialectDescriptor ");
1998436c6c9cSStella Laurenzo         repr.append(ns.data, ns.length);
1999436c6c9cSStella Laurenzo         repr.append(">");
2000436c6c9cSStella Laurenzo         return repr;
2001436c6c9cSStella Laurenzo       });
2002436c6c9cSStella Laurenzo 
2003436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2004436c6c9cSStella Laurenzo   // Mapping of PyDialects
2005436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2006f05ff4f7SStella Laurenzo   py::class_<PyDialects>(m, "Dialects", py::module_local())
2007436c6c9cSStella Laurenzo       .def("__getitem__",
2008436c6c9cSStella Laurenzo            [=](PyDialects &self, std::string keyName) {
2009436c6c9cSStella Laurenzo              MlirDialect dialect =
2010436c6c9cSStella Laurenzo                  self.getDialectForKey(keyName, /*attrError=*/false);
2011436c6c9cSStella Laurenzo              py::object descriptor =
2012436c6c9cSStella Laurenzo                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2013436c6c9cSStella Laurenzo              return createCustomDialectWrapper(keyName, std::move(descriptor));
2014436c6c9cSStella Laurenzo            })
2015436c6c9cSStella Laurenzo       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2016436c6c9cSStella Laurenzo         MlirDialect dialect =
2017436c6c9cSStella Laurenzo             self.getDialectForKey(attrName, /*attrError=*/true);
2018436c6c9cSStella Laurenzo         py::object descriptor =
2019436c6c9cSStella Laurenzo             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2020436c6c9cSStella Laurenzo         return createCustomDialectWrapper(attrName, std::move(descriptor));
2021436c6c9cSStella Laurenzo       });
2022436c6c9cSStella Laurenzo 
2023436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2024436c6c9cSStella Laurenzo   // Mapping of PyDialect
2025436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2026f05ff4f7SStella Laurenzo   py::class_<PyDialect>(m, "Dialect", py::module_local())
2027*a6e7d024SStella Laurenzo       .def(py::init<py::object>(), py::arg("descriptor"))
2028436c6c9cSStella Laurenzo       .def_property_readonly(
2029436c6c9cSStella Laurenzo           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2030436c6c9cSStella Laurenzo       .def("__repr__", [](py::object self) {
2031436c6c9cSStella Laurenzo         auto clazz = self.attr("__class__");
2032436c6c9cSStella Laurenzo         return py::str("<Dialect ") +
2033436c6c9cSStella Laurenzo                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2034436c6c9cSStella Laurenzo                clazz.attr("__module__") + py::str(".") +
2035436c6c9cSStella Laurenzo                clazz.attr("__name__") + py::str(")>");
2036436c6c9cSStella Laurenzo       });
2037436c6c9cSStella Laurenzo 
2038436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2039436c6c9cSStella Laurenzo   // Mapping of Location
2040436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2041f05ff4f7SStella Laurenzo   py::class_<PyLocation>(m, "Location", py::module_local())
2042436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2043436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2044436c6c9cSStella Laurenzo       .def("__enter__", &PyLocation::contextEnter)
2045436c6c9cSStella Laurenzo       .def("__exit__", &PyLocation::contextExit)
2046436c6c9cSStella Laurenzo       .def("__eq__",
2047436c6c9cSStella Laurenzo            [](PyLocation &self, PyLocation &other) -> bool {
2048436c6c9cSStella Laurenzo              return mlirLocationEqual(self, other);
2049436c6c9cSStella Laurenzo            })
2050436c6c9cSStella Laurenzo       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2051436c6c9cSStella Laurenzo       .def_property_readonly_static(
2052436c6c9cSStella Laurenzo           "current",
2053436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
2054436c6c9cSStella Laurenzo             auto *loc = PyThreadContextEntry::getDefaultLocation();
2055436c6c9cSStella Laurenzo             if (!loc)
2056436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current Location");
2057436c6c9cSStella Laurenzo             return loc;
2058436c6c9cSStella Laurenzo           },
2059436c6c9cSStella Laurenzo           "Gets the Location bound to the current thread or raises ValueError")
2060436c6c9cSStella Laurenzo       .def_static(
2061436c6c9cSStella Laurenzo           "unknown",
2062436c6c9cSStella Laurenzo           [](DefaultingPyMlirContext context) {
2063436c6c9cSStella Laurenzo             return PyLocation(context->getRef(),
2064436c6c9cSStella Laurenzo                               mlirLocationUnknownGet(context->get()));
2065436c6c9cSStella Laurenzo           },
2066436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
2067436c6c9cSStella Laurenzo           "Gets a Location representing an unknown location")
2068436c6c9cSStella Laurenzo       .def_static(
2069e67cbbefSJacques Pienaar           "callsite",
2070e67cbbefSJacques Pienaar           [](PyLocation callee, const std::vector<PyLocation> &frames,
2071e67cbbefSJacques Pienaar              DefaultingPyMlirContext context) {
2072e67cbbefSJacques Pienaar             if (frames.empty())
2073e67cbbefSJacques Pienaar               throw py::value_error("No caller frames provided");
2074e67cbbefSJacques Pienaar             MlirLocation caller = frames.back().get();
2075e2f16be5SMehdi Amini             for (const PyLocation &frame :
2076e67cbbefSJacques Pienaar                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2077e67cbbefSJacques Pienaar               caller = mlirLocationCallSiteGet(frame.get(), caller);
2078e67cbbefSJacques Pienaar             return PyLocation(context->getRef(),
2079e67cbbefSJacques Pienaar                               mlirLocationCallSiteGet(callee.get(), caller));
2080e67cbbefSJacques Pienaar           },
2081e67cbbefSJacques Pienaar           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2082e67cbbefSJacques Pienaar           kContextGetCallSiteLocationDocstring)
2083e67cbbefSJacques Pienaar       .def_static(
2084436c6c9cSStella Laurenzo           "file",
2085436c6c9cSStella Laurenzo           [](std::string filename, int line, int col,
2086436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
2087436c6c9cSStella Laurenzo             return PyLocation(
2088436c6c9cSStella Laurenzo                 context->getRef(),
2089436c6c9cSStella Laurenzo                 mlirLocationFileLineColGet(
2090436c6c9cSStella Laurenzo                     context->get(), toMlirStringRef(filename), line, col));
2091436c6c9cSStella Laurenzo           },
2092436c6c9cSStella Laurenzo           py::arg("filename"), py::arg("line"), py::arg("col"),
2093436c6c9cSStella Laurenzo           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
209404d76d36SJacques Pienaar       .def_static(
209504d76d36SJacques Pienaar           "name",
209604d76d36SJacques Pienaar           [](std::string name, llvm::Optional<PyLocation> childLoc,
209704d76d36SJacques Pienaar              DefaultingPyMlirContext context) {
209804d76d36SJacques Pienaar             return PyLocation(
209904d76d36SJacques Pienaar                 context->getRef(),
210004d76d36SJacques Pienaar                 mlirLocationNameGet(
210104d76d36SJacques Pienaar                     context->get(), toMlirStringRef(name),
210204d76d36SJacques Pienaar                     childLoc ? childLoc->get()
210304d76d36SJacques Pienaar                              : mlirLocationUnknownGet(context->get())));
210404d76d36SJacques Pienaar           },
210504d76d36SJacques Pienaar           py::arg("name"), py::arg("childLoc") = py::none(),
210604d76d36SJacques Pienaar           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2107436c6c9cSStella Laurenzo       .def_property_readonly(
2108436c6c9cSStella Laurenzo           "context",
2109436c6c9cSStella Laurenzo           [](PyLocation &self) { return self.getContext().getObject(); },
2110436c6c9cSStella Laurenzo           "Context that owns the Location")
2111436c6c9cSStella Laurenzo       .def("__repr__", [](PyLocation &self) {
2112436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
2113436c6c9cSStella Laurenzo         mlirLocationPrint(self, printAccum.getCallback(),
2114436c6c9cSStella Laurenzo                           printAccum.getUserData());
2115436c6c9cSStella Laurenzo         return printAccum.join();
2116436c6c9cSStella Laurenzo       });
2117436c6c9cSStella Laurenzo 
2118436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2119436c6c9cSStella Laurenzo   // Mapping of Module
2120436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2121f05ff4f7SStella Laurenzo   py::class_<PyModule>(m, "Module", py::module_local())
2122436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2123436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2124436c6c9cSStella Laurenzo       .def_static(
2125436c6c9cSStella Laurenzo           "parse",
2126436c6c9cSStella Laurenzo           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2127436c6c9cSStella Laurenzo             MlirModule module = mlirModuleCreateParse(
2128436c6c9cSStella Laurenzo                 context->get(), toMlirStringRef(moduleAsm));
2129436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
2130436c6c9cSStella Laurenzo             // in C API.
2131436c6c9cSStella Laurenzo             if (mlirModuleIsNull(module)) {
2132436c6c9cSStella Laurenzo               throw SetPyError(
2133436c6c9cSStella Laurenzo                   PyExc_ValueError,
2134436c6c9cSStella Laurenzo                   "Unable to parse module assembly (see diagnostics)");
2135436c6c9cSStella Laurenzo             }
2136436c6c9cSStella Laurenzo             return PyModule::forModule(module).releaseObject();
2137436c6c9cSStella Laurenzo           },
2138436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
2139436c6c9cSStella Laurenzo           kModuleParseDocstring)
2140436c6c9cSStella Laurenzo       .def_static(
2141436c6c9cSStella Laurenzo           "create",
2142436c6c9cSStella Laurenzo           [](DefaultingPyLocation loc) {
2143436c6c9cSStella Laurenzo             MlirModule module = mlirModuleCreateEmpty(loc);
2144436c6c9cSStella Laurenzo             return PyModule::forModule(module).releaseObject();
2145436c6c9cSStella Laurenzo           },
2146436c6c9cSStella Laurenzo           py::arg("loc") = py::none(), "Creates an empty module")
2147436c6c9cSStella Laurenzo       .def_property_readonly(
2148436c6c9cSStella Laurenzo           "context",
2149436c6c9cSStella Laurenzo           [](PyModule &self) { return self.getContext().getObject(); },
2150436c6c9cSStella Laurenzo           "Context that created the Module")
2151436c6c9cSStella Laurenzo       .def_property_readonly(
2152436c6c9cSStella Laurenzo           "operation",
2153436c6c9cSStella Laurenzo           [](PyModule &self) {
2154436c6c9cSStella Laurenzo             return PyOperation::forOperation(self.getContext(),
2155436c6c9cSStella Laurenzo                                              mlirModuleGetOperation(self.get()),
2156436c6c9cSStella Laurenzo                                              self.getRef().releaseObject())
2157436c6c9cSStella Laurenzo                 .releaseObject();
2158436c6c9cSStella Laurenzo           },
2159436c6c9cSStella Laurenzo           "Accesses the module as an operation")
2160436c6c9cSStella Laurenzo       .def_property_readonly(
2161436c6c9cSStella Laurenzo           "body",
2162436c6c9cSStella Laurenzo           [](PyModule &self) {
2163436c6c9cSStella Laurenzo             PyOperationRef module_op = PyOperation::forOperation(
2164436c6c9cSStella Laurenzo                 self.getContext(), mlirModuleGetOperation(self.get()),
2165436c6c9cSStella Laurenzo                 self.getRef().releaseObject());
2166436c6c9cSStella Laurenzo             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2167436c6c9cSStella Laurenzo             return returnBlock;
2168436c6c9cSStella Laurenzo           },
2169436c6c9cSStella Laurenzo           "Return the block for this module")
2170436c6c9cSStella Laurenzo       .def(
2171436c6c9cSStella Laurenzo           "dump",
2172436c6c9cSStella Laurenzo           [](PyModule &self) {
2173436c6c9cSStella Laurenzo             mlirOperationDump(mlirModuleGetOperation(self.get()));
2174436c6c9cSStella Laurenzo           },
2175436c6c9cSStella Laurenzo           kDumpDocstring)
2176436c6c9cSStella Laurenzo       .def(
2177436c6c9cSStella Laurenzo           "__str__",
2178ace1d0adSStella Laurenzo           [](py::object self) {
2179ace1d0adSStella Laurenzo             // Defer to the operation's __str__.
2180ace1d0adSStella Laurenzo             return self.attr("operation").attr("__str__")();
2181436c6c9cSStella Laurenzo           },
2182436c6c9cSStella Laurenzo           kOperationStrDunderDocstring);
2183436c6c9cSStella Laurenzo 
2184436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2185436c6c9cSStella Laurenzo   // Mapping of Operation.
2186436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2187f05ff4f7SStella Laurenzo   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
21881fb2e842SStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
21891fb2e842SStella Laurenzo                              [](PyOperationBase &self) {
21901fb2e842SStella Laurenzo                                return self.getOperation().getCapsule();
21911fb2e842SStella Laurenzo                              })
2192436c6c9cSStella Laurenzo       .def("__eq__",
2193436c6c9cSStella Laurenzo            [](PyOperationBase &self, PyOperationBase &other) {
2194436c6c9cSStella Laurenzo              return &self.getOperation() == &other.getOperation();
2195436c6c9cSStella Laurenzo            })
2196436c6c9cSStella Laurenzo       .def("__eq__",
2197436c6c9cSStella Laurenzo            [](PyOperationBase &self, py::object other) { return false; })
2198f78fe0b7Srkayaith       .def("__hash__",
2199f78fe0b7Srkayaith            [](PyOperationBase &self) {
2200f78fe0b7Srkayaith              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2201f78fe0b7Srkayaith            })
2202436c6c9cSStella Laurenzo       .def_property_readonly("attributes",
2203436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
2204436c6c9cSStella Laurenzo                                return PyOpAttributeMap(
2205436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
2206436c6c9cSStella Laurenzo                              })
2207436c6c9cSStella Laurenzo       .def_property_readonly("operands",
2208436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
2209436c6c9cSStella Laurenzo                                return PyOpOperandList(
2210436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
2211436c6c9cSStella Laurenzo                              })
2212436c6c9cSStella Laurenzo       .def_property_readonly("regions",
2213436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
2214436c6c9cSStella Laurenzo                                return PyRegionList(
2215436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
2216436c6c9cSStella Laurenzo                              })
2217436c6c9cSStella Laurenzo       .def_property_readonly(
2218436c6c9cSStella Laurenzo           "results",
2219436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2220436c6c9cSStella Laurenzo             return PyOpResultList(self.getOperation().getRef());
2221436c6c9cSStella Laurenzo           },
2222436c6c9cSStella Laurenzo           "Returns the list of Operation results.")
2223436c6c9cSStella Laurenzo       .def_property_readonly(
2224436c6c9cSStella Laurenzo           "result",
2225436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2226436c6c9cSStella Laurenzo             auto &operation = self.getOperation();
2227436c6c9cSStella Laurenzo             auto numResults = mlirOperationGetNumResults(operation);
2228436c6c9cSStella Laurenzo             if (numResults != 1) {
2229436c6c9cSStella Laurenzo               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2230436c6c9cSStella Laurenzo               throw SetPyError(
2231436c6c9cSStella Laurenzo                   PyExc_ValueError,
2232436c6c9cSStella Laurenzo                   Twine("Cannot call .result on operation ") +
2233436c6c9cSStella Laurenzo                       StringRef(name.data, name.length) + " which has " +
2234436c6c9cSStella Laurenzo                       Twine(numResults) +
2235436c6c9cSStella Laurenzo                       " results (it is only valid for operations with a "
2236436c6c9cSStella Laurenzo                       "single result)");
2237436c6c9cSStella Laurenzo             }
2238436c6c9cSStella Laurenzo             return PyOpResult(operation.getRef(),
2239436c6c9cSStella Laurenzo                               mlirOperationGetResult(operation, 0));
2240436c6c9cSStella Laurenzo           },
2241436c6c9cSStella Laurenzo           "Shortcut to get an op result if it has only one (throws an error "
2242436c6c9cSStella Laurenzo           "otherwise).")
2243d5429a13Srkayaith       .def_property_readonly(
2244d5429a13Srkayaith           "location",
2245d5429a13Srkayaith           [](PyOperationBase &self) {
2246d5429a13Srkayaith             PyOperation &operation = self.getOperation();
2247d5429a13Srkayaith             return PyLocation(operation.getContext(),
2248d5429a13Srkayaith                               mlirOperationGetLocation(operation.get()));
2249d5429a13Srkayaith           },
2250d5429a13Srkayaith           "Returns the source location the operation was defined or derived "
2251d5429a13Srkayaith           "from.")
2252436c6c9cSStella Laurenzo       .def(
2253436c6c9cSStella Laurenzo           "__str__",
2254436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2255436c6c9cSStella Laurenzo             return self.getAsm(/*binary=*/false,
2256436c6c9cSStella Laurenzo                                /*largeElementsLimit=*/llvm::None,
2257436c6c9cSStella Laurenzo                                /*enableDebugInfo=*/false,
2258436c6c9cSStella Laurenzo                                /*prettyDebugInfo=*/false,
2259436c6c9cSStella Laurenzo                                /*printGenericOpForm=*/false,
2260ace1d0adSStella Laurenzo                                /*useLocalScope=*/false,
2261ace1d0adSStella Laurenzo                                /*assumeVerified=*/false);
2262436c6c9cSStella Laurenzo           },
2263436c6c9cSStella Laurenzo           "Returns the assembly form of the operation.")
2264436c6c9cSStella Laurenzo       .def("print", &PyOperationBase::print,
2265436c6c9cSStella Laurenzo            // Careful: Lots of arguments must match up with print method.
2266436c6c9cSStella Laurenzo            py::arg("file") = py::none(), py::arg("binary") = false,
2267436c6c9cSStella Laurenzo            py::arg("large_elements_limit") = py::none(),
2268436c6c9cSStella Laurenzo            py::arg("enable_debug_info") = false,
2269436c6c9cSStella Laurenzo            py::arg("pretty_debug_info") = false,
2270436c6c9cSStella Laurenzo            py::arg("print_generic_op_form") = false,
2271ace1d0adSStella Laurenzo            py::arg("use_local_scope") = false,
2272ace1d0adSStella Laurenzo            py::arg("assume_verified") = false, kOperationPrintDocstring)
2273436c6c9cSStella Laurenzo       .def("get_asm", &PyOperationBase::getAsm,
2274436c6c9cSStella Laurenzo            // Careful: Lots of arguments must match up with get_asm method.
2275436c6c9cSStella Laurenzo            py::arg("binary") = false,
2276436c6c9cSStella Laurenzo            py::arg("large_elements_limit") = py::none(),
2277436c6c9cSStella Laurenzo            py::arg("enable_debug_info") = false,
2278436c6c9cSStella Laurenzo            py::arg("pretty_debug_info") = false,
2279436c6c9cSStella Laurenzo            py::arg("print_generic_op_form") = false,
2280ace1d0adSStella Laurenzo            py::arg("use_local_scope") = false,
2281ace1d0adSStella Laurenzo            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
2282436c6c9cSStella Laurenzo       .def(
2283436c6c9cSStella Laurenzo           "verify",
2284436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2285436c6c9cSStella Laurenzo             return mlirOperationVerify(self.getOperation());
2286436c6c9cSStella Laurenzo           },
2287436c6c9cSStella Laurenzo           "Verify the operation and return true if it passes, false if it "
228824685aaeSAlex Zinenko           "fails.")
228924685aaeSAlex Zinenko       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
229024685aaeSAlex Zinenko            "Puts self immediately after the other operation in its parent "
229124685aaeSAlex Zinenko            "block.")
229224685aaeSAlex Zinenko       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
229324685aaeSAlex Zinenko            "Puts self immediately before the other operation in its parent "
229424685aaeSAlex Zinenko            "block.")
229524685aaeSAlex Zinenko       .def(
229624685aaeSAlex Zinenko           "detach_from_parent",
229724685aaeSAlex Zinenko           [](PyOperationBase &self) {
229824685aaeSAlex Zinenko             PyOperation &operation = self.getOperation();
229924685aaeSAlex Zinenko             operation.checkValid();
230024685aaeSAlex Zinenko             if (!operation.isAttached())
230124685aaeSAlex Zinenko               throw py::value_error("Detached operation has no parent.");
230224685aaeSAlex Zinenko 
230324685aaeSAlex Zinenko             operation.detachFromParent();
230424685aaeSAlex Zinenko             return operation.createOpView();
230524685aaeSAlex Zinenko           },
230624685aaeSAlex Zinenko           "Detaches the operation from its parent block.");
2307436c6c9cSStella Laurenzo 
2308f05ff4f7SStella Laurenzo   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2309436c6c9cSStella Laurenzo       .def_static("create", &PyOperation::create, py::arg("name"),
2310436c6c9cSStella Laurenzo                   py::arg("results") = py::none(),
2311436c6c9cSStella Laurenzo                   py::arg("operands") = py::none(),
2312436c6c9cSStella Laurenzo                   py::arg("attributes") = py::none(),
2313436c6c9cSStella Laurenzo                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2314436c6c9cSStella Laurenzo                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2315436c6c9cSStella Laurenzo                   kOperationCreateDocstring)
2316c65bb760SJohn Demme       .def_property_readonly("parent",
23171689dadeSJohn Demme                              [](PyOperation &self) -> py::object {
23181689dadeSJohn Demme                                auto parent = self.getParentOperation();
23191689dadeSJohn Demme                                if (parent)
23201689dadeSJohn Demme                                  return parent->getObject();
23211689dadeSJohn Demme                                return py::none();
2322c65bb760SJohn Demme                              })
232349745f87SMike Urbach       .def("erase", &PyOperation::erase)
23240126e906SJohn Demme       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
23250126e906SJohn Demme                              &PyOperation::getCapsule)
23260126e906SJohn Demme       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2327436c6c9cSStella Laurenzo       .def_property_readonly("name",
2328436c6c9cSStella Laurenzo                              [](PyOperation &self) {
232949745f87SMike Urbach                                self.checkValid();
2330436c6c9cSStella Laurenzo                                MlirOperation operation = self.get();
2331436c6c9cSStella Laurenzo                                MlirStringRef name = mlirIdentifierStr(
2332436c6c9cSStella Laurenzo                                    mlirOperationGetName(operation));
2333436c6c9cSStella Laurenzo                                return py::str(name.data, name.length);
2334436c6c9cSStella Laurenzo                              })
2335436c6c9cSStella Laurenzo       .def_property_readonly(
2336436c6c9cSStella Laurenzo           "context",
233749745f87SMike Urbach           [](PyOperation &self) {
233849745f87SMike Urbach             self.checkValid();
233949745f87SMike Urbach             return self.getContext().getObject();
234049745f87SMike Urbach           },
2341436c6c9cSStella Laurenzo           "Context that owns the Operation")
2342436c6c9cSStella Laurenzo       .def_property_readonly("opview", &PyOperation::createOpView);
2343436c6c9cSStella Laurenzo 
2344436c6c9cSStella Laurenzo   auto opViewClass =
2345f05ff4f7SStella Laurenzo       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2346*a6e7d024SStella Laurenzo           .def(py::init<py::object>(), py::arg("operation"))
2347436c6c9cSStella Laurenzo           .def_property_readonly("operation", &PyOpView::getOperationObject)
2348436c6c9cSStella Laurenzo           .def_property_readonly(
2349436c6c9cSStella Laurenzo               "context",
2350436c6c9cSStella Laurenzo               [](PyOpView &self) {
2351436c6c9cSStella Laurenzo                 return self.getOperation().getContext().getObject();
2352436c6c9cSStella Laurenzo               },
2353436c6c9cSStella Laurenzo               "Context that owns the Operation")
2354436c6c9cSStella Laurenzo           .def("__str__", [](PyOpView &self) {
2355436c6c9cSStella Laurenzo             return py::str(self.getOperationObject());
2356436c6c9cSStella Laurenzo           });
2357436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2358436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2359436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2360436c6c9cSStella Laurenzo   opViewClass.attr("build_generic") = classmethod(
2361436c6c9cSStella Laurenzo       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2362436c6c9cSStella Laurenzo       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2363436c6c9cSStella Laurenzo       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2364436c6c9cSStella Laurenzo       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2365436c6c9cSStella Laurenzo       "Builds a specific, generated OpView based on class level attributes.");
2366436c6c9cSStella Laurenzo 
2367436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2368436c6c9cSStella Laurenzo   // Mapping of PyRegion.
2369436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2370f05ff4f7SStella Laurenzo   py::class_<PyRegion>(m, "Region", py::module_local())
2371436c6c9cSStella Laurenzo       .def_property_readonly(
2372436c6c9cSStella Laurenzo           "blocks",
2373436c6c9cSStella Laurenzo           [](PyRegion &self) {
2374436c6c9cSStella Laurenzo             return PyBlockList(self.getParentOperation(), self.get());
2375436c6c9cSStella Laurenzo           },
2376436c6c9cSStella Laurenzo           "Returns a forward-optimized sequence of blocks.")
237778f2dae0SAlex Zinenko       .def_property_readonly(
237878f2dae0SAlex Zinenko           "owner",
237978f2dae0SAlex Zinenko           [](PyRegion &self) {
238078f2dae0SAlex Zinenko             return self.getParentOperation()->createOpView();
238178f2dae0SAlex Zinenko           },
238278f2dae0SAlex Zinenko           "Returns the operation owning this region.")
2383436c6c9cSStella Laurenzo       .def(
2384436c6c9cSStella Laurenzo           "__iter__",
2385436c6c9cSStella Laurenzo           [](PyRegion &self) {
2386436c6c9cSStella Laurenzo             self.checkValid();
2387436c6c9cSStella Laurenzo             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2388436c6c9cSStella Laurenzo             return PyBlockIterator(self.getParentOperation(), firstBlock);
2389436c6c9cSStella Laurenzo           },
2390436c6c9cSStella Laurenzo           "Iterates over blocks in the region.")
2391436c6c9cSStella Laurenzo       .def("__eq__",
2392436c6c9cSStella Laurenzo            [](PyRegion &self, PyRegion &other) {
2393436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2394436c6c9cSStella Laurenzo            })
2395436c6c9cSStella Laurenzo       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2396436c6c9cSStella Laurenzo 
2397436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2398436c6c9cSStella Laurenzo   // Mapping of PyBlock.
2399436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2400f05ff4f7SStella Laurenzo   py::class_<PyBlock>(m, "Block", py::module_local())
2401436c6c9cSStella Laurenzo       .def_property_readonly(
240296fbd5cdSJohn Demme           "owner",
240396fbd5cdSJohn Demme           [](PyBlock &self) {
240496fbd5cdSJohn Demme             return self.getParentOperation()->createOpView();
240596fbd5cdSJohn Demme           },
240696fbd5cdSJohn Demme           "Returns the owning operation of this block.")
240796fbd5cdSJohn Demme       .def_property_readonly(
24088e6c55c9SStella Laurenzo           "region",
24098e6c55c9SStella Laurenzo           [](PyBlock &self) {
24108e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
24118e6c55c9SStella Laurenzo             return PyRegion(self.getParentOperation(), region);
24128e6c55c9SStella Laurenzo           },
24138e6c55c9SStella Laurenzo           "Returns the owning region of this block.")
24148e6c55c9SStella Laurenzo       .def_property_readonly(
2415436c6c9cSStella Laurenzo           "arguments",
2416436c6c9cSStella Laurenzo           [](PyBlock &self) {
2417436c6c9cSStella Laurenzo             return PyBlockArgumentList(self.getParentOperation(), self.get());
2418436c6c9cSStella Laurenzo           },
2419436c6c9cSStella Laurenzo           "Returns a list of block arguments.")
2420436c6c9cSStella Laurenzo       .def_property_readonly(
2421436c6c9cSStella Laurenzo           "operations",
2422436c6c9cSStella Laurenzo           [](PyBlock &self) {
2423436c6c9cSStella Laurenzo             return PyOperationList(self.getParentOperation(), self.get());
2424436c6c9cSStella Laurenzo           },
2425436c6c9cSStella Laurenzo           "Returns a forward-optimized sequence of operations.")
242678f2dae0SAlex Zinenko       .def_static(
242778f2dae0SAlex Zinenko           "create_at_start",
242878f2dae0SAlex Zinenko           [](PyRegion &parent, py::list pyArgTypes) {
242978f2dae0SAlex Zinenko             parent.checkValid();
243078f2dae0SAlex Zinenko             llvm::SmallVector<MlirType, 4> argTypes;
243178f2dae0SAlex Zinenko             argTypes.reserve(pyArgTypes.size());
243278f2dae0SAlex Zinenko             for (auto &pyArg : pyArgTypes) {
243378f2dae0SAlex Zinenko               argTypes.push_back(pyArg.cast<PyType &>());
243478f2dae0SAlex Zinenko             }
243578f2dae0SAlex Zinenko 
243678f2dae0SAlex Zinenko             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
243778f2dae0SAlex Zinenko             mlirRegionInsertOwnedBlock(parent, 0, block);
243878f2dae0SAlex Zinenko             return PyBlock(parent.getParentOperation(), block);
243978f2dae0SAlex Zinenko           },
2440*a6e7d024SStella Laurenzo           py::arg("parent"), py::arg("arg_types") = py::list(),
244178f2dae0SAlex Zinenko           "Creates and returns a new Block at the beginning of the given "
244278f2dae0SAlex Zinenko           "region (with given argument types).")
2443436c6c9cSStella Laurenzo       .def(
24448e6c55c9SStella Laurenzo           "create_before",
24458e6c55c9SStella Laurenzo           [](PyBlock &self, py::args pyArgTypes) {
24468e6c55c9SStella Laurenzo             self.checkValid();
24478e6c55c9SStella Laurenzo             llvm::SmallVector<MlirType, 4> argTypes;
24488e6c55c9SStella Laurenzo             argTypes.reserve(pyArgTypes.size());
24498e6c55c9SStella Laurenzo             for (auto &pyArg : pyArgTypes) {
24508e6c55c9SStella Laurenzo               argTypes.push_back(pyArg.cast<PyType &>());
24518e6c55c9SStella Laurenzo             }
24528e6c55c9SStella Laurenzo 
24538e6c55c9SStella Laurenzo             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
24548e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
24558e6c55c9SStella Laurenzo             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
24568e6c55c9SStella Laurenzo             return PyBlock(self.getParentOperation(), block);
24578e6c55c9SStella Laurenzo           },
24588e6c55c9SStella Laurenzo           "Creates and returns a new Block before this block "
24598e6c55c9SStella Laurenzo           "(with given argument types).")
24608e6c55c9SStella Laurenzo       .def(
24618e6c55c9SStella Laurenzo           "create_after",
24628e6c55c9SStella Laurenzo           [](PyBlock &self, py::args pyArgTypes) {
24638e6c55c9SStella Laurenzo             self.checkValid();
24648e6c55c9SStella Laurenzo             llvm::SmallVector<MlirType, 4> argTypes;
24658e6c55c9SStella Laurenzo             argTypes.reserve(pyArgTypes.size());
24668e6c55c9SStella Laurenzo             for (auto &pyArg : pyArgTypes) {
24678e6c55c9SStella Laurenzo               argTypes.push_back(pyArg.cast<PyType &>());
24688e6c55c9SStella Laurenzo             }
24698e6c55c9SStella Laurenzo 
24708e6c55c9SStella Laurenzo             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
24718e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
24728e6c55c9SStella Laurenzo             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
24738e6c55c9SStella Laurenzo             return PyBlock(self.getParentOperation(), block);
24748e6c55c9SStella Laurenzo           },
24758e6c55c9SStella Laurenzo           "Creates and returns a new Block after this block "
24768e6c55c9SStella Laurenzo           "(with given argument types).")
24778e6c55c9SStella Laurenzo       .def(
2478436c6c9cSStella Laurenzo           "__iter__",
2479436c6c9cSStella Laurenzo           [](PyBlock &self) {
2480436c6c9cSStella Laurenzo             self.checkValid();
2481436c6c9cSStella Laurenzo             MlirOperation firstOperation =
2482436c6c9cSStella Laurenzo                 mlirBlockGetFirstOperation(self.get());
2483436c6c9cSStella Laurenzo             return PyOperationIterator(self.getParentOperation(),
2484436c6c9cSStella Laurenzo                                        firstOperation);
2485436c6c9cSStella Laurenzo           },
2486436c6c9cSStella Laurenzo           "Iterates over operations in the block.")
2487436c6c9cSStella Laurenzo       .def("__eq__",
2488436c6c9cSStella Laurenzo            [](PyBlock &self, PyBlock &other) {
2489436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2490436c6c9cSStella Laurenzo            })
2491436c6c9cSStella Laurenzo       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2492436c6c9cSStella Laurenzo       .def(
2493436c6c9cSStella Laurenzo           "__str__",
2494436c6c9cSStella Laurenzo           [](PyBlock &self) {
2495436c6c9cSStella Laurenzo             self.checkValid();
2496436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2497436c6c9cSStella Laurenzo             mlirBlockPrint(self.get(), printAccum.getCallback(),
2498436c6c9cSStella Laurenzo                            printAccum.getUserData());
2499436c6c9cSStella Laurenzo             return printAccum.join();
2500436c6c9cSStella Laurenzo           },
250124685aaeSAlex Zinenko           "Returns the assembly form of the block.")
250224685aaeSAlex Zinenko       .def(
250324685aaeSAlex Zinenko           "append",
250424685aaeSAlex Zinenko           [](PyBlock &self, PyOperationBase &operation) {
250524685aaeSAlex Zinenko             if (operation.getOperation().isAttached())
250624685aaeSAlex Zinenko               operation.getOperation().detachFromParent();
250724685aaeSAlex Zinenko 
250824685aaeSAlex Zinenko             MlirOperation mlirOperation = operation.getOperation().get();
250924685aaeSAlex Zinenko             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
251024685aaeSAlex Zinenko             operation.getOperation().setAttached(
251124685aaeSAlex Zinenko                 self.getParentOperation().getObject());
251224685aaeSAlex Zinenko           },
2513*a6e7d024SStella Laurenzo           py::arg("operation"),
251424685aaeSAlex Zinenko           "Appends an operation to this block. If the operation is currently "
251524685aaeSAlex Zinenko           "in another block, it will be moved.");
2516436c6c9cSStella Laurenzo 
2517436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2518436c6c9cSStella Laurenzo   // Mapping of PyInsertionPoint.
2519436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2520436c6c9cSStella Laurenzo 
2521f05ff4f7SStella Laurenzo   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2522436c6c9cSStella Laurenzo       .def(py::init<PyBlock &>(), py::arg("block"),
2523436c6c9cSStella Laurenzo            "Inserts after the last operation but still inside the block.")
2524436c6c9cSStella Laurenzo       .def("__enter__", &PyInsertionPoint::contextEnter)
2525436c6c9cSStella Laurenzo       .def("__exit__", &PyInsertionPoint::contextExit)
2526436c6c9cSStella Laurenzo       .def_property_readonly_static(
2527436c6c9cSStella Laurenzo           "current",
2528436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
2529436c6c9cSStella Laurenzo             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2530436c6c9cSStella Laurenzo             if (!ip)
2531436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2532436c6c9cSStella Laurenzo             return ip;
2533436c6c9cSStella Laurenzo           },
2534436c6c9cSStella Laurenzo           "Gets the InsertionPoint bound to the current thread or raises "
2535436c6c9cSStella Laurenzo           "ValueError if none has been set")
2536436c6c9cSStella Laurenzo       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2537436c6c9cSStella Laurenzo            "Inserts before a referenced operation.")
2538436c6c9cSStella Laurenzo       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2539436c6c9cSStella Laurenzo                   py::arg("block"), "Inserts at the beginning of the block.")
2540436c6c9cSStella Laurenzo       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2541436c6c9cSStella Laurenzo                   py::arg("block"), "Inserts before the block terminator.")
2542436c6c9cSStella Laurenzo       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
25438e6c55c9SStella Laurenzo            "Inserts an operation.")
25448e6c55c9SStella Laurenzo       .def_property_readonly(
25458e6c55c9SStella Laurenzo           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
25468e6c55c9SStella Laurenzo           "Returns the block that this InsertionPoint points to.");
2547436c6c9cSStella Laurenzo 
2548436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2549436c6c9cSStella Laurenzo   // Mapping of PyAttribute.
2550436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2551f05ff4f7SStella Laurenzo   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2552b57d6fe4SStella Laurenzo       // Delegate to the PyAttribute copy constructor, which will also lifetime
2553b57d6fe4SStella Laurenzo       // extend the backing context which owns the MlirAttribute.
2554b57d6fe4SStella Laurenzo       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2555b57d6fe4SStella Laurenzo            "Casts the passed attribute to the generic Attribute")
2556436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2557436c6c9cSStella Laurenzo                              &PyAttribute::getCapsule)
2558436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2559436c6c9cSStella Laurenzo       .def_static(
2560436c6c9cSStella Laurenzo           "parse",
2561436c6c9cSStella Laurenzo           [](std::string attrSpec, DefaultingPyMlirContext context) {
2562436c6c9cSStella Laurenzo             MlirAttribute type = mlirAttributeParseGet(
2563436c6c9cSStella Laurenzo                 context->get(), toMlirStringRef(attrSpec));
2564436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
2565436c6c9cSStella Laurenzo             // in C API.
2566436c6c9cSStella Laurenzo             if (mlirAttributeIsNull(type)) {
2567436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
2568436c6c9cSStella Laurenzo                                Twine("Unable to parse attribute: '") +
2569436c6c9cSStella Laurenzo                                    attrSpec + "'");
2570436c6c9cSStella Laurenzo             }
2571436c6c9cSStella Laurenzo             return PyAttribute(context->getRef(), type);
2572436c6c9cSStella Laurenzo           },
2573436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
2574436c6c9cSStella Laurenzo           "Parses an attribute from an assembly form")
2575436c6c9cSStella Laurenzo       .def_property_readonly(
2576436c6c9cSStella Laurenzo           "context",
2577436c6c9cSStella Laurenzo           [](PyAttribute &self) { return self.getContext().getObject(); },
2578436c6c9cSStella Laurenzo           "Context that owns the Attribute")
2579436c6c9cSStella Laurenzo       .def_property_readonly("type",
2580436c6c9cSStella Laurenzo                              [](PyAttribute &self) {
2581436c6c9cSStella Laurenzo                                return PyType(self.getContext()->getRef(),
2582436c6c9cSStella Laurenzo                                              mlirAttributeGetType(self));
2583436c6c9cSStella Laurenzo                              })
2584436c6c9cSStella Laurenzo       .def(
2585436c6c9cSStella Laurenzo           "get_named",
2586436c6c9cSStella Laurenzo           [](PyAttribute &self, std::string name) {
2587436c6c9cSStella Laurenzo             return PyNamedAttribute(self, std::move(name));
2588436c6c9cSStella Laurenzo           },
2589436c6c9cSStella Laurenzo           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2590436c6c9cSStella Laurenzo       .def("__eq__",
2591436c6c9cSStella Laurenzo            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2592436c6c9cSStella Laurenzo       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2593f78fe0b7Srkayaith       .def("__hash__",
2594f78fe0b7Srkayaith            [](PyAttribute &self) {
2595f78fe0b7Srkayaith              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2596f78fe0b7Srkayaith            })
2597436c6c9cSStella Laurenzo       .def(
2598436c6c9cSStella Laurenzo           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2599436c6c9cSStella Laurenzo           kDumpDocstring)
2600436c6c9cSStella Laurenzo       .def(
2601436c6c9cSStella Laurenzo           "__str__",
2602436c6c9cSStella Laurenzo           [](PyAttribute &self) {
2603436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2604436c6c9cSStella Laurenzo             mlirAttributePrint(self, printAccum.getCallback(),
2605436c6c9cSStella Laurenzo                                printAccum.getUserData());
2606436c6c9cSStella Laurenzo             return printAccum.join();
2607436c6c9cSStella Laurenzo           },
2608436c6c9cSStella Laurenzo           "Returns the assembly form of the Attribute.")
2609436c6c9cSStella Laurenzo       .def("__repr__", [](PyAttribute &self) {
2610436c6c9cSStella Laurenzo         // Generally, assembly formats are not printed for __repr__ because
2611436c6c9cSStella Laurenzo         // this can cause exceptionally long debug output and exceptions.
2612436c6c9cSStella Laurenzo         // However, attribute values are generally considered useful and are
2613436c6c9cSStella Laurenzo         // printed. This may need to be re-evaluated if debug dumps end up
2614436c6c9cSStella Laurenzo         // being excessive.
2615436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
2616436c6c9cSStella Laurenzo         printAccum.parts.append("Attribute(");
2617436c6c9cSStella Laurenzo         mlirAttributePrint(self, printAccum.getCallback(),
2618436c6c9cSStella Laurenzo                            printAccum.getUserData());
2619436c6c9cSStella Laurenzo         printAccum.parts.append(")");
2620436c6c9cSStella Laurenzo         return printAccum.join();
2621436c6c9cSStella Laurenzo       });
2622436c6c9cSStella Laurenzo 
2623436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2624436c6c9cSStella Laurenzo   // Mapping of PyNamedAttribute
2625436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2626f05ff4f7SStella Laurenzo   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2627436c6c9cSStella Laurenzo       .def("__repr__",
2628436c6c9cSStella Laurenzo            [](PyNamedAttribute &self) {
2629436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
2630436c6c9cSStella Laurenzo              printAccum.parts.append("NamedAttribute(");
2631436c6c9cSStella Laurenzo              printAccum.parts.append(
2632120591e1SRiver Riddle                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
2633120591e1SRiver Riddle                          mlirIdentifierStr(self.namedAttr.name).length));
2634436c6c9cSStella Laurenzo              printAccum.parts.append("=");
2635436c6c9cSStella Laurenzo              mlirAttributePrint(self.namedAttr.attribute,
2636436c6c9cSStella Laurenzo                                 printAccum.getCallback(),
2637436c6c9cSStella Laurenzo                                 printAccum.getUserData());
2638436c6c9cSStella Laurenzo              printAccum.parts.append(")");
2639436c6c9cSStella Laurenzo              return printAccum.join();
2640436c6c9cSStella Laurenzo            })
2641436c6c9cSStella Laurenzo       .def_property_readonly(
2642436c6c9cSStella Laurenzo           "name",
2643436c6c9cSStella Laurenzo           [](PyNamedAttribute &self) {
2644436c6c9cSStella Laurenzo             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2645436c6c9cSStella Laurenzo                            mlirIdentifierStr(self.namedAttr.name).length);
2646436c6c9cSStella Laurenzo           },
2647436c6c9cSStella Laurenzo           "The name of the NamedAttribute binding")
2648436c6c9cSStella Laurenzo       .def_property_readonly(
2649436c6c9cSStella Laurenzo           "attr",
2650436c6c9cSStella Laurenzo           [](PyNamedAttribute &self) {
2651436c6c9cSStella Laurenzo             // TODO: When named attribute is removed/refactored, also remove
2652436c6c9cSStella Laurenzo             // this constructor (it does an inefficient table lookup).
2653436c6c9cSStella Laurenzo             auto contextRef = PyMlirContext::forContext(
2654436c6c9cSStella Laurenzo                 mlirAttributeGetContext(self.namedAttr.attribute));
2655436c6c9cSStella Laurenzo             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2656436c6c9cSStella Laurenzo           },
2657436c6c9cSStella Laurenzo           py::keep_alive<0, 1>(),
2658436c6c9cSStella Laurenzo           "The underlying generic attribute of the NamedAttribute binding");
2659436c6c9cSStella Laurenzo 
2660436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2661436c6c9cSStella Laurenzo   // Mapping of PyType.
2662436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2663f05ff4f7SStella Laurenzo   py::class_<PyType>(m, "Type", py::module_local())
2664b57d6fe4SStella Laurenzo       // Delegate to the PyType copy constructor, which will also lifetime
2665b57d6fe4SStella Laurenzo       // extend the backing context which owns the MlirType.
2666b57d6fe4SStella Laurenzo       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2667b57d6fe4SStella Laurenzo            "Casts the passed type to the generic Type")
2668436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2669436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2670436c6c9cSStella Laurenzo       .def_static(
2671436c6c9cSStella Laurenzo           "parse",
2672436c6c9cSStella Laurenzo           [](std::string typeSpec, DefaultingPyMlirContext context) {
2673436c6c9cSStella Laurenzo             MlirType type =
2674436c6c9cSStella Laurenzo                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2675436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
2676436c6c9cSStella Laurenzo             // in C API.
2677436c6c9cSStella Laurenzo             if (mlirTypeIsNull(type)) {
2678436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
2679436c6c9cSStella Laurenzo                                Twine("Unable to parse type: '") + typeSpec +
2680436c6c9cSStella Laurenzo                                    "'");
2681436c6c9cSStella Laurenzo             }
2682436c6c9cSStella Laurenzo             return PyType(context->getRef(), type);
2683436c6c9cSStella Laurenzo           },
2684436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
2685436c6c9cSStella Laurenzo           kContextParseTypeDocstring)
2686436c6c9cSStella Laurenzo       .def_property_readonly(
2687436c6c9cSStella Laurenzo           "context", [](PyType &self) { return self.getContext().getObject(); },
2688436c6c9cSStella Laurenzo           "Context that owns the Type")
2689436c6c9cSStella Laurenzo       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2690436c6c9cSStella Laurenzo       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2691f78fe0b7Srkayaith       .def("__hash__",
2692f78fe0b7Srkayaith            [](PyType &self) {
2693f78fe0b7Srkayaith              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2694f78fe0b7Srkayaith            })
2695436c6c9cSStella Laurenzo       .def(
2696436c6c9cSStella Laurenzo           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2697436c6c9cSStella Laurenzo       .def(
2698436c6c9cSStella Laurenzo           "__str__",
2699436c6c9cSStella Laurenzo           [](PyType &self) {
2700436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2701436c6c9cSStella Laurenzo             mlirTypePrint(self, printAccum.getCallback(),
2702436c6c9cSStella Laurenzo                           printAccum.getUserData());
2703436c6c9cSStella Laurenzo             return printAccum.join();
2704436c6c9cSStella Laurenzo           },
2705436c6c9cSStella Laurenzo           "Returns the assembly form of the type.")
2706436c6c9cSStella Laurenzo       .def("__repr__", [](PyType &self) {
2707436c6c9cSStella Laurenzo         // Generally, assembly formats are not printed for __repr__ because
2708436c6c9cSStella Laurenzo         // this can cause exceptionally long debug output and exceptions.
2709436c6c9cSStella Laurenzo         // However, types are an exception as they typically have compact
2710436c6c9cSStella Laurenzo         // assembly forms and printing them is useful.
2711436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
2712436c6c9cSStella Laurenzo         printAccum.parts.append("Type(");
2713436c6c9cSStella Laurenzo         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2714436c6c9cSStella Laurenzo         printAccum.parts.append(")");
2715436c6c9cSStella Laurenzo         return printAccum.join();
2716436c6c9cSStella Laurenzo       });
2717436c6c9cSStella Laurenzo 
2718436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2719436c6c9cSStella Laurenzo   // Mapping of Value.
2720436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2721f05ff4f7SStella Laurenzo   py::class_<PyValue>(m, "Value", py::module_local())
27223f3d1c90SMike Urbach       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
27233f3d1c90SMike Urbach       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2724436c6c9cSStella Laurenzo       .def_property_readonly(
2725436c6c9cSStella Laurenzo           "context",
2726436c6c9cSStella Laurenzo           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2727436c6c9cSStella Laurenzo           "Context in which the value lives.")
2728436c6c9cSStella Laurenzo       .def(
2729436c6c9cSStella Laurenzo           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2730436c6c9cSStella Laurenzo           kDumpDocstring)
27315664c5e2SJohn Demme       .def_property_readonly(
27325664c5e2SJohn Demme           "owner",
27335664c5e2SJohn Demme           [](PyValue &self) {
27345664c5e2SJohn Demme             assert(mlirOperationEqual(self.getParentOperation()->get(),
27355664c5e2SJohn Demme                                       mlirOpResultGetOwner(self.get())) &&
27365664c5e2SJohn Demme                    "expected the owner of the value in Python to match that in "
27375664c5e2SJohn Demme                    "the IR");
27385664c5e2SJohn Demme             return self.getParentOperation().getObject();
27395664c5e2SJohn Demme           })
2740436c6c9cSStella Laurenzo       .def("__eq__",
2741436c6c9cSStella Laurenzo            [](PyValue &self, PyValue &other) {
2742436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2743436c6c9cSStella Laurenzo            })
2744436c6c9cSStella Laurenzo       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2745f78fe0b7Srkayaith       .def("__hash__",
2746f78fe0b7Srkayaith            [](PyValue &self) {
2747f78fe0b7Srkayaith              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2748f78fe0b7Srkayaith            })
2749436c6c9cSStella Laurenzo       .def(
2750436c6c9cSStella Laurenzo           "__str__",
2751436c6c9cSStella Laurenzo           [](PyValue &self) {
2752436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2753436c6c9cSStella Laurenzo             printAccum.parts.append("Value(");
2754436c6c9cSStella Laurenzo             mlirValuePrint(self.get(), printAccum.getCallback(),
2755436c6c9cSStella Laurenzo                            printAccum.getUserData());
2756436c6c9cSStella Laurenzo             printAccum.parts.append(")");
2757436c6c9cSStella Laurenzo             return printAccum.join();
2758436c6c9cSStella Laurenzo           },
2759436c6c9cSStella Laurenzo           kValueDunderStrDocstring)
2760436c6c9cSStella Laurenzo       .def_property_readonly("type", [](PyValue &self) {
2761436c6c9cSStella Laurenzo         return PyType(self.getParentOperation()->getContext(),
2762436c6c9cSStella Laurenzo                       mlirValueGetType(self.get()));
2763436c6c9cSStella Laurenzo       });
2764436c6c9cSStella Laurenzo   PyBlockArgument::bind(m);
2765436c6c9cSStella Laurenzo   PyOpResult::bind(m);
2766436c6c9cSStella Laurenzo 
276730d61893SAlex Zinenko   //----------------------------------------------------------------------------
276830d61893SAlex Zinenko   // Mapping of SymbolTable.
276930d61893SAlex Zinenko   //----------------------------------------------------------------------------
277030d61893SAlex Zinenko   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
277130d61893SAlex Zinenko       .def(py::init<PyOperationBase &>())
277230d61893SAlex Zinenko       .def("__getitem__", &PySymbolTable::dunderGetItem)
2773*a6e7d024SStella Laurenzo       .def("insert", &PySymbolTable::insert, py::arg("operation"))
2774*a6e7d024SStella Laurenzo       .def("erase", &PySymbolTable::erase, py::arg("operation"))
277530d61893SAlex Zinenko       .def("__delitem__", &PySymbolTable::dunderDel)
277630d61893SAlex Zinenko       .def("__contains__", [](PySymbolTable &table, const std::string &name) {
277730d61893SAlex Zinenko         return !mlirOperationIsNull(mlirSymbolTableLookup(
277830d61893SAlex Zinenko             table, mlirStringRefCreate(name.data(), name.length())));
277930d61893SAlex Zinenko       });
278030d61893SAlex Zinenko 
2781436c6c9cSStella Laurenzo   // Container bindings: register the Python classes for the list/iterator
  // helper types. Several of them (PyBlockList, PyBlockIterator,
  // PyBlockArgumentList, PyOperationList, PyOperationIterator) are returned by
  // the Region/Block properties bound above.
  // NOTE(review): registration order is kept as-is; whether any of these
  // bind() calls depends on another having run first is not visible here.
2782436c6c9cSStella Laurenzo   PyBlockArgumentList::bind(m);
2783436c6c9cSStella Laurenzo   PyBlockIterator::bind(m);
2784436c6c9cSStella Laurenzo   PyBlockList::bind(m);
2785436c6c9cSStella Laurenzo   PyOperationIterator::bind(m);
2786436c6c9cSStella Laurenzo   PyOperationList::bind(m);
2787436c6c9cSStella Laurenzo   PyOpAttributeMap::bind(m);
2788436c6c9cSStella Laurenzo   PyOpOperandList::bind(m);
2789436c6c9cSStella Laurenzo   PyOpResultList::bind(m);
2790436c6c9cSStella Laurenzo   PyRegionIterator::bind(m);
2791436c6c9cSStella Laurenzo   PyRegionList::bind(m);
27924acd8457SAlex Zinenko
27934acd8457SAlex Zinenko   // Debug bindings: PyGlobalDebugFlag is defined elsewhere; presumably it
  // exposes MLIR's global debug flag from mlir-c/Debug.h (included at the top
  // of the file) - confirm against its bind().
27944acd8457SAlex Zinenko   PyGlobalDebugFlag::bind(m);
2795436c6c9cSStella Laurenzo }
2796