1436c6c9cSStella Laurenzo //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9436c6c9cSStella Laurenzo #include "IRModule.h"
10436c6c9cSStella Laurenzo 
11436c6c9cSStella Laurenzo #include "Globals.h"
12436c6c9cSStella Laurenzo #include "PybindUtils.h"
13436c6c9cSStella Laurenzo 
14436c6c9cSStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h"
15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
174acd8457SAlex Zinenko #include "mlir-c/Debug.h"
183f3d1c90SMike Urbach #include "mlir-c/IR.h"
19436c6c9cSStella Laurenzo #include "mlir-c/Registration.h"
20*e67cbbefSJacques Pienaar #include "llvm/ADT/ArrayRef.h"
21436c6c9cSStella Laurenzo #include "llvm/ADT/SmallVector.h"
22436c6c9cSStella Laurenzo #include <pybind11/stl.h>
23436c6c9cSStella Laurenzo 
24436c6c9cSStella Laurenzo namespace py = pybind11;
25436c6c9cSStella Laurenzo using namespace mlir;
26436c6c9cSStella Laurenzo using namespace mlir::python;
27436c6c9cSStella Laurenzo 
28436c6c9cSStella Laurenzo using llvm::SmallVector;
29436c6c9cSStella Laurenzo using llvm::StringRef;
30436c6c9cSStella Laurenzo using llvm::Twine;
31436c6c9cSStella Laurenzo 
32436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
33436c6c9cSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
34436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
35436c6c9cSStella Laurenzo 
// Python docstring text for the Context type-parsing binding.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Python docstring text for building a call-site Location (caller + callee).
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

// Python docstring text for building a file/line/column Location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Python docstring text for building a named Location (optionally wrapping a
// child location).
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

// Python docstring text for parsing a module from its textual assembly.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Python docstring text for the generic operation-creation entry point. The
// names under "Args" are meant to mirror the binding's keyword arguments.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
78436c6c9cSStella Laurenzo 
// Python docstring text for Operation.print(). Every name listed under "Args"
// is a Python keyword argument; since keyword names are case-sensitive, they
// must be spelled exactly as the binding declares them (e.g. `use_local_scope`,
// not `use_local_Scope`), or users following the docs get a TypeError.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";
97436c6c9cSStella Laurenzo 
// Python docstring text for Operation.get_asm(). It defers to the print()
// docstring for the shared formatting keyword arguments.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Python docstring text for Operation.__str__ (default printing options).
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared Python docstring text for `dump()` methods.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Python docstring text for BlockList.append (see PyBlockList::appendBlock in
// this file): positional args are the new block's argument types.
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Python docstring text for Value.__str__.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
136436c6c9cSStella Laurenzo 
137436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
138436c6c9cSStella Laurenzo // Utilities.
139436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
140436c6c9cSStella Laurenzo 
1414acd8457SAlex Zinenko /// Helper for creating an @classmethod.
142436c6c9cSStella Laurenzo template <class Func, typename... Args>
143436c6c9cSStella Laurenzo py::object classmethod(Func f, Args... args) {
144436c6c9cSStella Laurenzo   py::object cf = py::cpp_function(f, args...);
145436c6c9cSStella Laurenzo   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
146436c6c9cSStella Laurenzo }
147436c6c9cSStella Laurenzo 
148436c6c9cSStella Laurenzo static py::object
149436c6c9cSStella Laurenzo createCustomDialectWrapper(const std::string &dialectNamespace,
150436c6c9cSStella Laurenzo                            py::object dialectDescriptor) {
151436c6c9cSStella Laurenzo   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
152436c6c9cSStella Laurenzo   if (!dialectClass) {
153436c6c9cSStella Laurenzo     // Use the base class.
154436c6c9cSStella Laurenzo     return py::cast(PyDialect(std::move(dialectDescriptor)));
155436c6c9cSStella Laurenzo   }
156436c6c9cSStella Laurenzo 
157436c6c9cSStella Laurenzo   // Create the custom implementation.
158436c6c9cSStella Laurenzo   return (*dialectClass)(std::move(dialectDescriptor));
159436c6c9cSStella Laurenzo }
160436c6c9cSStella Laurenzo 
161436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
162436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
163436c6c9cSStella Laurenzo }
164436c6c9cSStella Laurenzo 
1654acd8457SAlex Zinenko /// Wrapper for the global LLVM debugging flag.
1664acd8457SAlex Zinenko struct PyGlobalDebugFlag {
1674acd8457SAlex Zinenko   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
1684acd8457SAlex Zinenko 
1694acd8457SAlex Zinenko   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
1704acd8457SAlex Zinenko 
1714acd8457SAlex Zinenko   static void bind(py::module &m) {
1724acd8457SAlex Zinenko     // Debug flags.
173f05ff4f7SStella Laurenzo     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
1744acd8457SAlex Zinenko         .def_property_static("flag", &PyGlobalDebugFlag::get,
1754acd8457SAlex Zinenko                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
1764acd8457SAlex Zinenko   }
1774acd8457SAlex Zinenko };
1784acd8457SAlex Zinenko 
179436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
180436c6c9cSStella Laurenzo // Collections.
181436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
182436c6c9cSStella Laurenzo 
183436c6c9cSStella Laurenzo namespace {
184436c6c9cSStella Laurenzo 
185436c6c9cSStella Laurenzo class PyRegionIterator {
186436c6c9cSStella Laurenzo public:
187436c6c9cSStella Laurenzo   PyRegionIterator(PyOperationRef operation)
188436c6c9cSStella Laurenzo       : operation(std::move(operation)) {}
189436c6c9cSStella Laurenzo 
190436c6c9cSStella Laurenzo   PyRegionIterator &dunderIter() { return *this; }
191436c6c9cSStella Laurenzo 
192436c6c9cSStella Laurenzo   PyRegion dunderNext() {
193436c6c9cSStella Laurenzo     operation->checkValid();
194436c6c9cSStella Laurenzo     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
195436c6c9cSStella Laurenzo       throw py::stop_iteration();
196436c6c9cSStella Laurenzo     }
197436c6c9cSStella Laurenzo     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
198436c6c9cSStella Laurenzo     return PyRegion(operation, region);
199436c6c9cSStella Laurenzo   }
200436c6c9cSStella Laurenzo 
201436c6c9cSStella Laurenzo   static void bind(py::module &m) {
202f05ff4f7SStella Laurenzo     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
203436c6c9cSStella Laurenzo         .def("__iter__", &PyRegionIterator::dunderIter)
204436c6c9cSStella Laurenzo         .def("__next__", &PyRegionIterator::dunderNext);
205436c6c9cSStella Laurenzo   }
206436c6c9cSStella Laurenzo 
207436c6c9cSStella Laurenzo private:
208436c6c9cSStella Laurenzo   PyOperationRef operation;
209436c6c9cSStella Laurenzo   int nextIndex = 0;
210436c6c9cSStella Laurenzo };
211436c6c9cSStella Laurenzo 
212436c6c9cSStella Laurenzo /// Regions of an op are fixed length and indexed numerically so are represented
213436c6c9cSStella Laurenzo /// with a sequence-like container.
214436c6c9cSStella Laurenzo class PyRegionList {
215436c6c9cSStella Laurenzo public:
216436c6c9cSStella Laurenzo   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
217436c6c9cSStella Laurenzo 
218436c6c9cSStella Laurenzo   intptr_t dunderLen() {
219436c6c9cSStella Laurenzo     operation->checkValid();
220436c6c9cSStella Laurenzo     return mlirOperationGetNumRegions(operation->get());
221436c6c9cSStella Laurenzo   }
222436c6c9cSStella Laurenzo 
223436c6c9cSStella Laurenzo   PyRegion dunderGetItem(intptr_t index) {
224436c6c9cSStella Laurenzo     // dunderLen checks validity.
225436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
226436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
227436c6c9cSStella Laurenzo                        "attempt to access out of bounds region");
228436c6c9cSStella Laurenzo     }
229436c6c9cSStella Laurenzo     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
230436c6c9cSStella Laurenzo     return PyRegion(operation, region);
231436c6c9cSStella Laurenzo   }
232436c6c9cSStella Laurenzo 
233436c6c9cSStella Laurenzo   static void bind(py::module &m) {
234f05ff4f7SStella Laurenzo     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
235436c6c9cSStella Laurenzo         .def("__len__", &PyRegionList::dunderLen)
236436c6c9cSStella Laurenzo         .def("__getitem__", &PyRegionList::dunderGetItem);
237436c6c9cSStella Laurenzo   }
238436c6c9cSStella Laurenzo 
239436c6c9cSStella Laurenzo private:
240436c6c9cSStella Laurenzo   PyOperationRef operation;
241436c6c9cSStella Laurenzo };
242436c6c9cSStella Laurenzo 
243436c6c9cSStella Laurenzo class PyBlockIterator {
244436c6c9cSStella Laurenzo public:
245436c6c9cSStella Laurenzo   PyBlockIterator(PyOperationRef operation, MlirBlock next)
246436c6c9cSStella Laurenzo       : operation(std::move(operation)), next(next) {}
247436c6c9cSStella Laurenzo 
248436c6c9cSStella Laurenzo   PyBlockIterator &dunderIter() { return *this; }
249436c6c9cSStella Laurenzo 
250436c6c9cSStella Laurenzo   PyBlock dunderNext() {
251436c6c9cSStella Laurenzo     operation->checkValid();
252436c6c9cSStella Laurenzo     if (mlirBlockIsNull(next)) {
253436c6c9cSStella Laurenzo       throw py::stop_iteration();
254436c6c9cSStella Laurenzo     }
255436c6c9cSStella Laurenzo 
256436c6c9cSStella Laurenzo     PyBlock returnBlock(operation, next);
257436c6c9cSStella Laurenzo     next = mlirBlockGetNextInRegion(next);
258436c6c9cSStella Laurenzo     return returnBlock;
259436c6c9cSStella Laurenzo   }
260436c6c9cSStella Laurenzo 
261436c6c9cSStella Laurenzo   static void bind(py::module &m) {
262f05ff4f7SStella Laurenzo     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
263436c6c9cSStella Laurenzo         .def("__iter__", &PyBlockIterator::dunderIter)
264436c6c9cSStella Laurenzo         .def("__next__", &PyBlockIterator::dunderNext);
265436c6c9cSStella Laurenzo   }
266436c6c9cSStella Laurenzo 
267436c6c9cSStella Laurenzo private:
268436c6c9cSStella Laurenzo   PyOperationRef operation;
269436c6c9cSStella Laurenzo   MlirBlock next;
270436c6c9cSStella Laurenzo };
271436c6c9cSStella Laurenzo 
272436c6c9cSStella Laurenzo /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
273436c6c9cSStella Laurenzo /// we present them as a more full-featured list-like container but optimize
274436c6c9cSStella Laurenzo /// it for forward iteration. Blocks are always owned by a region.
275436c6c9cSStella Laurenzo class PyBlockList {
276436c6c9cSStella Laurenzo public:
277436c6c9cSStella Laurenzo   PyBlockList(PyOperationRef operation, MlirRegion region)
278436c6c9cSStella Laurenzo       : operation(std::move(operation)), region(region) {}
279436c6c9cSStella Laurenzo 
280436c6c9cSStella Laurenzo   PyBlockIterator dunderIter() {
281436c6c9cSStella Laurenzo     operation->checkValid();
282436c6c9cSStella Laurenzo     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
283436c6c9cSStella Laurenzo   }
284436c6c9cSStella Laurenzo 
285436c6c9cSStella Laurenzo   intptr_t dunderLen() {
286436c6c9cSStella Laurenzo     operation->checkValid();
287436c6c9cSStella Laurenzo     intptr_t count = 0;
288436c6c9cSStella Laurenzo     MlirBlock block = mlirRegionGetFirstBlock(region);
289436c6c9cSStella Laurenzo     while (!mlirBlockIsNull(block)) {
290436c6c9cSStella Laurenzo       count += 1;
291436c6c9cSStella Laurenzo       block = mlirBlockGetNextInRegion(block);
292436c6c9cSStella Laurenzo     }
293436c6c9cSStella Laurenzo     return count;
294436c6c9cSStella Laurenzo   }
295436c6c9cSStella Laurenzo 
296436c6c9cSStella Laurenzo   PyBlock dunderGetItem(intptr_t index) {
297436c6c9cSStella Laurenzo     operation->checkValid();
298436c6c9cSStella Laurenzo     if (index < 0) {
299436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
300436c6c9cSStella Laurenzo                        "attempt to access out of bounds block");
301436c6c9cSStella Laurenzo     }
302436c6c9cSStella Laurenzo     MlirBlock block = mlirRegionGetFirstBlock(region);
303436c6c9cSStella Laurenzo     while (!mlirBlockIsNull(block)) {
304436c6c9cSStella Laurenzo       if (index == 0) {
305436c6c9cSStella Laurenzo         return PyBlock(operation, block);
306436c6c9cSStella Laurenzo       }
307436c6c9cSStella Laurenzo       block = mlirBlockGetNextInRegion(block);
308436c6c9cSStella Laurenzo       index -= 1;
309436c6c9cSStella Laurenzo     }
310436c6c9cSStella Laurenzo     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
311436c6c9cSStella Laurenzo   }
312436c6c9cSStella Laurenzo 
313436c6c9cSStella Laurenzo   PyBlock appendBlock(py::args pyArgTypes) {
314436c6c9cSStella Laurenzo     operation->checkValid();
315436c6c9cSStella Laurenzo     llvm::SmallVector<MlirType, 4> argTypes;
316436c6c9cSStella Laurenzo     argTypes.reserve(pyArgTypes.size());
317436c6c9cSStella Laurenzo     for (auto &pyArg : pyArgTypes) {
318436c6c9cSStella Laurenzo       argTypes.push_back(pyArg.cast<PyType &>());
319436c6c9cSStella Laurenzo     }
320436c6c9cSStella Laurenzo 
321436c6c9cSStella Laurenzo     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
322436c6c9cSStella Laurenzo     mlirRegionAppendOwnedBlock(region, block);
323436c6c9cSStella Laurenzo     return PyBlock(operation, block);
324436c6c9cSStella Laurenzo   }
325436c6c9cSStella Laurenzo 
326436c6c9cSStella Laurenzo   static void bind(py::module &m) {
327f05ff4f7SStella Laurenzo     py::class_<PyBlockList>(m, "BlockList", py::module_local())
328436c6c9cSStella Laurenzo         .def("__getitem__", &PyBlockList::dunderGetItem)
329436c6c9cSStella Laurenzo         .def("__iter__", &PyBlockList::dunderIter)
330436c6c9cSStella Laurenzo         .def("__len__", &PyBlockList::dunderLen)
331436c6c9cSStella Laurenzo         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
332436c6c9cSStella Laurenzo   }
333436c6c9cSStella Laurenzo 
334436c6c9cSStella Laurenzo private:
335436c6c9cSStella Laurenzo   PyOperationRef operation;
336436c6c9cSStella Laurenzo   MlirRegion region;
337436c6c9cSStella Laurenzo };
338436c6c9cSStella Laurenzo 
339436c6c9cSStella Laurenzo class PyOperationIterator {
340436c6c9cSStella Laurenzo public:
341436c6c9cSStella Laurenzo   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
342436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), next(next) {}
343436c6c9cSStella Laurenzo 
344436c6c9cSStella Laurenzo   PyOperationIterator &dunderIter() { return *this; }
345436c6c9cSStella Laurenzo 
346436c6c9cSStella Laurenzo   py::object dunderNext() {
347436c6c9cSStella Laurenzo     parentOperation->checkValid();
348436c6c9cSStella Laurenzo     if (mlirOperationIsNull(next)) {
349436c6c9cSStella Laurenzo       throw py::stop_iteration();
350436c6c9cSStella Laurenzo     }
351436c6c9cSStella Laurenzo 
352436c6c9cSStella Laurenzo     PyOperationRef returnOperation =
353436c6c9cSStella Laurenzo         PyOperation::forOperation(parentOperation->getContext(), next);
354436c6c9cSStella Laurenzo     next = mlirOperationGetNextInBlock(next);
355436c6c9cSStella Laurenzo     return returnOperation->createOpView();
356436c6c9cSStella Laurenzo   }
357436c6c9cSStella Laurenzo 
358436c6c9cSStella Laurenzo   static void bind(py::module &m) {
359f05ff4f7SStella Laurenzo     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
360436c6c9cSStella Laurenzo         .def("__iter__", &PyOperationIterator::dunderIter)
361436c6c9cSStella Laurenzo         .def("__next__", &PyOperationIterator::dunderNext);
362436c6c9cSStella Laurenzo   }
363436c6c9cSStella Laurenzo 
364436c6c9cSStella Laurenzo private:
365436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
366436c6c9cSStella Laurenzo   MlirOperation next;
367436c6c9cSStella Laurenzo };
368436c6c9cSStella Laurenzo 
369436c6c9cSStella Laurenzo /// Operations are exposed by the C-API as a forward-only linked list. In
370436c6c9cSStella Laurenzo /// Python, we present them as a more full-featured list-like container but
371436c6c9cSStella Laurenzo /// optimize it for forward iteration. Iterable operations are always owned
372436c6c9cSStella Laurenzo /// by a block.
373436c6c9cSStella Laurenzo class PyOperationList {
374436c6c9cSStella Laurenzo public:
375436c6c9cSStella Laurenzo   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
376436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), block(block) {}
377436c6c9cSStella Laurenzo 
378436c6c9cSStella Laurenzo   PyOperationIterator dunderIter() {
379436c6c9cSStella Laurenzo     parentOperation->checkValid();
380436c6c9cSStella Laurenzo     return PyOperationIterator(parentOperation,
381436c6c9cSStella Laurenzo                                mlirBlockGetFirstOperation(block));
382436c6c9cSStella Laurenzo   }
383436c6c9cSStella Laurenzo 
384436c6c9cSStella Laurenzo   intptr_t dunderLen() {
385436c6c9cSStella Laurenzo     parentOperation->checkValid();
386436c6c9cSStella Laurenzo     intptr_t count = 0;
387436c6c9cSStella Laurenzo     MlirOperation childOp = mlirBlockGetFirstOperation(block);
388436c6c9cSStella Laurenzo     while (!mlirOperationIsNull(childOp)) {
389436c6c9cSStella Laurenzo       count += 1;
390436c6c9cSStella Laurenzo       childOp = mlirOperationGetNextInBlock(childOp);
391436c6c9cSStella Laurenzo     }
392436c6c9cSStella Laurenzo     return count;
393436c6c9cSStella Laurenzo   }
394436c6c9cSStella Laurenzo 
395436c6c9cSStella Laurenzo   py::object dunderGetItem(intptr_t index) {
396436c6c9cSStella Laurenzo     parentOperation->checkValid();
397436c6c9cSStella Laurenzo     if (index < 0) {
398436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
399436c6c9cSStella Laurenzo                        "attempt to access out of bounds operation");
400436c6c9cSStella Laurenzo     }
401436c6c9cSStella Laurenzo     MlirOperation childOp = mlirBlockGetFirstOperation(block);
402436c6c9cSStella Laurenzo     while (!mlirOperationIsNull(childOp)) {
403436c6c9cSStella Laurenzo       if (index == 0) {
404436c6c9cSStella Laurenzo         return PyOperation::forOperation(parentOperation->getContext(), childOp)
405436c6c9cSStella Laurenzo             ->createOpView();
406436c6c9cSStella Laurenzo       }
407436c6c9cSStella Laurenzo       childOp = mlirOperationGetNextInBlock(childOp);
408436c6c9cSStella Laurenzo       index -= 1;
409436c6c9cSStella Laurenzo     }
410436c6c9cSStella Laurenzo     throw SetPyError(PyExc_IndexError,
411436c6c9cSStella Laurenzo                      "attempt to access out of bounds operation");
412436c6c9cSStella Laurenzo   }
413436c6c9cSStella Laurenzo 
414436c6c9cSStella Laurenzo   static void bind(py::module &m) {
415f05ff4f7SStella Laurenzo     py::class_<PyOperationList>(m, "OperationList", py::module_local())
416436c6c9cSStella Laurenzo         .def("__getitem__", &PyOperationList::dunderGetItem)
417436c6c9cSStella Laurenzo         .def("__iter__", &PyOperationList::dunderIter)
418436c6c9cSStella Laurenzo         .def("__len__", &PyOperationList::dunderLen);
419436c6c9cSStella Laurenzo   }
420436c6c9cSStella Laurenzo 
421436c6c9cSStella Laurenzo private:
422436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
423436c6c9cSStella Laurenzo   MlirBlock block;
424436c6c9cSStella Laurenzo };
425436c6c9cSStella Laurenzo 
426436c6c9cSStella Laurenzo } // namespace
427436c6c9cSStella Laurenzo 
428436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
429436c6c9cSStella Laurenzo // PyMlirContext
430436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
431436c6c9cSStella Laurenzo 
432436c6c9cSStella Laurenzo PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
433436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
434436c6c9cSStella Laurenzo   auto &liveContexts = getLiveContexts();
435436c6c9cSStella Laurenzo   liveContexts[context.ptr] = this;
436436c6c9cSStella Laurenzo }
437436c6c9cSStella Laurenzo 
438436c6c9cSStella Laurenzo PyMlirContext::~PyMlirContext() {
439436c6c9cSStella Laurenzo   // Note that the only public way to construct an instance is via the
440436c6c9cSStella Laurenzo   // forContext method, which always puts the associated handle into
441436c6c9cSStella Laurenzo   // liveContexts.
442436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
443436c6c9cSStella Laurenzo   getLiveContexts().erase(context.ptr);
444436c6c9cSStella Laurenzo   mlirContextDestroy(context);
445436c6c9cSStella Laurenzo }
446436c6c9cSStella Laurenzo 
447436c6c9cSStella Laurenzo py::object PyMlirContext::getCapsule() {
448436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
449436c6c9cSStella Laurenzo }
450436c6c9cSStella Laurenzo 
451436c6c9cSStella Laurenzo py::object PyMlirContext::createFromCapsule(py::object capsule) {
452436c6c9cSStella Laurenzo   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
453436c6c9cSStella Laurenzo   if (mlirContextIsNull(rawContext))
454436c6c9cSStella Laurenzo     throw py::error_already_set();
455436c6c9cSStella Laurenzo   return forContext(rawContext).releaseObject();
456436c6c9cSStella Laurenzo }
457436c6c9cSStella Laurenzo 
458436c6c9cSStella Laurenzo PyMlirContext *PyMlirContext::createNewContextForInit() {
459436c6c9cSStella Laurenzo   MlirContext context = mlirContextCreate();
460436c6c9cSStella Laurenzo   mlirRegisterAllDialects(context);
461436c6c9cSStella Laurenzo   return new PyMlirContext(context);
462436c6c9cSStella Laurenzo }
463436c6c9cSStella Laurenzo 
464436c6c9cSStella Laurenzo PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
465436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
466436c6c9cSStella Laurenzo   auto &liveContexts = getLiveContexts();
467436c6c9cSStella Laurenzo   auto it = liveContexts.find(context.ptr);
468436c6c9cSStella Laurenzo   if (it == liveContexts.end()) {
469436c6c9cSStella Laurenzo     // Create.
470436c6c9cSStella Laurenzo     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
471436c6c9cSStella Laurenzo     py::object pyRef = py::cast(unownedContextWrapper);
472436c6c9cSStella Laurenzo     assert(pyRef && "cast to py::object failed");
473436c6c9cSStella Laurenzo     liveContexts[context.ptr] = unownedContextWrapper;
474436c6c9cSStella Laurenzo     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
475436c6c9cSStella Laurenzo   }
476436c6c9cSStella Laurenzo   // Use existing.
477436c6c9cSStella Laurenzo   py::object pyRef = py::cast(it->second);
478436c6c9cSStella Laurenzo   return PyMlirContextRef(it->second, std::move(pyRef));
479436c6c9cSStella Laurenzo }
480436c6c9cSStella Laurenzo 
481436c6c9cSStella Laurenzo PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
482436c6c9cSStella Laurenzo   static LiveContextMap liveContexts;
483436c6c9cSStella Laurenzo   return liveContexts;
484436c6c9cSStella Laurenzo }
485436c6c9cSStella Laurenzo 
486436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
487436c6c9cSStella Laurenzo 
488436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
489436c6c9cSStella Laurenzo 
490436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
491436c6c9cSStella Laurenzo 
492436c6c9cSStella Laurenzo pybind11::object PyMlirContext::contextEnter() {
493436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushContext(*this);
494436c6c9cSStella Laurenzo }
495436c6c9cSStella Laurenzo 
496436c6c9cSStella Laurenzo void PyMlirContext::contextExit(pybind11::object excType,
497436c6c9cSStella Laurenzo                                 pybind11::object excVal,
498436c6c9cSStella Laurenzo                                 pybind11::object excTb) {
499436c6c9cSStella Laurenzo   PyThreadContextEntry::popContext(*this);
500436c6c9cSStella Laurenzo }
501436c6c9cSStella Laurenzo 
502436c6c9cSStella Laurenzo PyMlirContext &DefaultingPyMlirContext::resolve() {
503436c6c9cSStella Laurenzo   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
504436c6c9cSStella Laurenzo   if (!context) {
505436c6c9cSStella Laurenzo     throw SetPyError(
506436c6c9cSStella Laurenzo         PyExc_RuntimeError,
507436c6c9cSStella Laurenzo         "An MLIR function requires a Context but none was provided in the call "
508436c6c9cSStella Laurenzo         "or from the surrounding environment. Either pass to the function with "
509436c6c9cSStella Laurenzo         "a 'context=' argument or establish a default using 'with Context():'");
510436c6c9cSStella Laurenzo   }
511436c6c9cSStella Laurenzo   return *context;
512436c6c9cSStella Laurenzo }
513436c6c9cSStella Laurenzo 
514436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
515436c6c9cSStella Laurenzo // PyThreadContextEntry management
516436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
517436c6c9cSStella Laurenzo 
518436c6c9cSStella Laurenzo std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
519436c6c9cSStella Laurenzo   static thread_local std::vector<PyThreadContextEntry> stack;
520436c6c9cSStella Laurenzo   return stack;
521436c6c9cSStella Laurenzo }
522436c6c9cSStella Laurenzo 
523436c6c9cSStella Laurenzo PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
524436c6c9cSStella Laurenzo   auto &stack = getStack();
525436c6c9cSStella Laurenzo   if (stack.empty())
526436c6c9cSStella Laurenzo     return nullptr;
527436c6c9cSStella Laurenzo   return &stack.back();
528436c6c9cSStella Laurenzo }
529436c6c9cSStella Laurenzo 
530436c6c9cSStella Laurenzo void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
531436c6c9cSStella Laurenzo                                 py::object insertionPoint,
532436c6c9cSStella Laurenzo                                 py::object location) {
533436c6c9cSStella Laurenzo   auto &stack = getStack();
534436c6c9cSStella Laurenzo   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
535436c6c9cSStella Laurenzo                      std::move(location));
536436c6c9cSStella Laurenzo   // If the new stack has more than one entry and the context of the new top
537436c6c9cSStella Laurenzo   // entry matches the previous, copy the insertionPoint and location from the
538436c6c9cSStella Laurenzo   // previous entry if missing from the new top entry.
539436c6c9cSStella Laurenzo   if (stack.size() > 1) {
540436c6c9cSStella Laurenzo     auto &prev = *(stack.rbegin() + 1);
541436c6c9cSStella Laurenzo     auto &current = stack.back();
542436c6c9cSStella Laurenzo     if (current.context.is(prev.context)) {
543436c6c9cSStella Laurenzo       // Default non-context objects from the previous entry.
544436c6c9cSStella Laurenzo       if (!current.insertionPoint)
545436c6c9cSStella Laurenzo         current.insertionPoint = prev.insertionPoint;
546436c6c9cSStella Laurenzo       if (!current.location)
547436c6c9cSStella Laurenzo         current.location = prev.location;
548436c6c9cSStella Laurenzo     }
549436c6c9cSStella Laurenzo   }
550436c6c9cSStella Laurenzo }
551436c6c9cSStella Laurenzo 
552436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getContext() {
553436c6c9cSStella Laurenzo   if (!context)
554436c6c9cSStella Laurenzo     return nullptr;
555436c6c9cSStella Laurenzo   return py::cast<PyMlirContext *>(context);
556436c6c9cSStella Laurenzo }
557436c6c9cSStella Laurenzo 
558436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
559436c6c9cSStella Laurenzo   if (!insertionPoint)
560436c6c9cSStella Laurenzo     return nullptr;
561436c6c9cSStella Laurenzo   return py::cast<PyInsertionPoint *>(insertionPoint);
562436c6c9cSStella Laurenzo }
563436c6c9cSStella Laurenzo 
564436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getLocation() {
565436c6c9cSStella Laurenzo   if (!location)
566436c6c9cSStella Laurenzo     return nullptr;
567436c6c9cSStella Laurenzo   return py::cast<PyLocation *>(location);
568436c6c9cSStella Laurenzo }
569436c6c9cSStella Laurenzo 
570436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getDefaultContext() {
571436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
572436c6c9cSStella Laurenzo   return tos ? tos->getContext() : nullptr;
573436c6c9cSStella Laurenzo }
574436c6c9cSStella Laurenzo 
575436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
576436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
577436c6c9cSStella Laurenzo   return tos ? tos->getInsertionPoint() : nullptr;
578436c6c9cSStella Laurenzo }
579436c6c9cSStella Laurenzo 
580436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getDefaultLocation() {
581436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
582436c6c9cSStella Laurenzo   return tos ? tos->getLocation() : nullptr;
583436c6c9cSStella Laurenzo }
584436c6c9cSStella Laurenzo 
585436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
586436c6c9cSStella Laurenzo   py::object contextObj = py::cast(context);
587436c6c9cSStella Laurenzo   push(FrameKind::Context, /*context=*/contextObj,
588436c6c9cSStella Laurenzo        /*insertionPoint=*/py::object(),
589436c6c9cSStella Laurenzo        /*location=*/py::object());
590436c6c9cSStella Laurenzo   return contextObj;
591436c6c9cSStella Laurenzo }
592436c6c9cSStella Laurenzo 
593436c6c9cSStella Laurenzo void PyThreadContextEntry::popContext(PyMlirContext &context) {
594436c6c9cSStella Laurenzo   auto &stack = getStack();
595436c6c9cSStella Laurenzo   if (stack.empty())
596436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
597436c6c9cSStella Laurenzo   auto &tos = stack.back();
598436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
599436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
600436c6c9cSStella Laurenzo   stack.pop_back();
601436c6c9cSStella Laurenzo }
602436c6c9cSStella Laurenzo 
603436c6c9cSStella Laurenzo py::object
604436c6c9cSStella Laurenzo PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
605436c6c9cSStella Laurenzo   py::object contextObj =
606436c6c9cSStella Laurenzo       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
607436c6c9cSStella Laurenzo   py::object insertionPointObj = py::cast(insertionPoint);
608436c6c9cSStella Laurenzo   push(FrameKind::InsertionPoint,
609436c6c9cSStella Laurenzo        /*context=*/contextObj,
610436c6c9cSStella Laurenzo        /*insertionPoint=*/insertionPointObj,
611436c6c9cSStella Laurenzo        /*location=*/py::object());
612436c6c9cSStella Laurenzo   return insertionPointObj;
613436c6c9cSStella Laurenzo }
614436c6c9cSStella Laurenzo 
615436c6c9cSStella Laurenzo void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
616436c6c9cSStella Laurenzo   auto &stack = getStack();
617436c6c9cSStella Laurenzo   if (stack.empty())
618436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError,
619436c6c9cSStella Laurenzo                      "Unbalanced InsertionPoint enter/exit");
620436c6c9cSStella Laurenzo   auto &tos = stack.back();
621436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::InsertionPoint &&
622436c6c9cSStella Laurenzo       tos.getInsertionPoint() != &insertionPoint)
623436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError,
624436c6c9cSStella Laurenzo                      "Unbalanced InsertionPoint enter/exit");
625436c6c9cSStella Laurenzo   stack.pop_back();
626436c6c9cSStella Laurenzo }
627436c6c9cSStella Laurenzo 
628436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
629436c6c9cSStella Laurenzo   py::object contextObj = location.getContext().getObject();
630436c6c9cSStella Laurenzo   py::object locationObj = py::cast(location);
631436c6c9cSStella Laurenzo   push(FrameKind::Location, /*context=*/contextObj,
632436c6c9cSStella Laurenzo        /*insertionPoint=*/py::object(),
633436c6c9cSStella Laurenzo        /*location=*/locationObj);
634436c6c9cSStella Laurenzo   return locationObj;
635436c6c9cSStella Laurenzo }
636436c6c9cSStella Laurenzo 
637436c6c9cSStella Laurenzo void PyThreadContextEntry::popLocation(PyLocation &location) {
638436c6c9cSStella Laurenzo   auto &stack = getStack();
639436c6c9cSStella Laurenzo   if (stack.empty())
640436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
641436c6c9cSStella Laurenzo   auto &tos = stack.back();
642436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
643436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
644436c6c9cSStella Laurenzo   stack.pop_back();
645436c6c9cSStella Laurenzo }
646436c6c9cSStella Laurenzo 
647436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
648436c6c9cSStella Laurenzo // PyDialect, PyDialectDescriptor, PyDialects
649436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
650436c6c9cSStella Laurenzo 
651436c6c9cSStella Laurenzo MlirDialect PyDialects::getDialectForKey(const std::string &key,
652436c6c9cSStella Laurenzo                                          bool attrError) {
653f8479d9dSRiver Riddle   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
654f8479d9dSRiver Riddle                                                     {key.data(), key.size()});
655436c6c9cSStella Laurenzo   if (mlirDialectIsNull(dialect)) {
656436c6c9cSStella Laurenzo     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
657436c6c9cSStella Laurenzo                      Twine("Dialect '") + key + "' not found");
658436c6c9cSStella Laurenzo   }
659436c6c9cSStella Laurenzo   return dialect;
660436c6c9cSStella Laurenzo }
661436c6c9cSStella Laurenzo 
662436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
663436c6c9cSStella Laurenzo // PyLocation
664436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
665436c6c9cSStella Laurenzo 
666436c6c9cSStella Laurenzo py::object PyLocation::getCapsule() {
667436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
668436c6c9cSStella Laurenzo }
669436c6c9cSStella Laurenzo 
670436c6c9cSStella Laurenzo PyLocation PyLocation::createFromCapsule(py::object capsule) {
671436c6c9cSStella Laurenzo   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
672436c6c9cSStella Laurenzo   if (mlirLocationIsNull(rawLoc))
673436c6c9cSStella Laurenzo     throw py::error_already_set();
674436c6c9cSStella Laurenzo   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
675436c6c9cSStella Laurenzo                     rawLoc);
676436c6c9cSStella Laurenzo }
677436c6c9cSStella Laurenzo 
678436c6c9cSStella Laurenzo py::object PyLocation::contextEnter() {
679436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushLocation(*this);
680436c6c9cSStella Laurenzo }
681436c6c9cSStella Laurenzo 
682436c6c9cSStella Laurenzo void PyLocation::contextExit(py::object excType, py::object excVal,
683436c6c9cSStella Laurenzo                              py::object excTb) {
684436c6c9cSStella Laurenzo   PyThreadContextEntry::popLocation(*this);
685436c6c9cSStella Laurenzo }
686436c6c9cSStella Laurenzo 
687436c6c9cSStella Laurenzo PyLocation &DefaultingPyLocation::resolve() {
688436c6c9cSStella Laurenzo   auto *location = PyThreadContextEntry::getDefaultLocation();
689436c6c9cSStella Laurenzo   if (!location) {
690436c6c9cSStella Laurenzo     throw SetPyError(
691436c6c9cSStella Laurenzo         PyExc_RuntimeError,
692436c6c9cSStella Laurenzo         "An MLIR function requires a Location but none was provided in the "
693436c6c9cSStella Laurenzo         "call or from the surrounding environment. Either pass to the function "
694436c6c9cSStella Laurenzo         "with a 'loc=' argument or establish a default using 'with loc:'");
695436c6c9cSStella Laurenzo   }
696436c6c9cSStella Laurenzo   return *location;
697436c6c9cSStella Laurenzo }
698436c6c9cSStella Laurenzo 
699436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
700436c6c9cSStella Laurenzo // PyModule
701436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
702436c6c9cSStella Laurenzo 
703436c6c9cSStella Laurenzo PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
704436c6c9cSStella Laurenzo     : BaseContextObject(std::move(contextRef)), module(module) {}
705436c6c9cSStella Laurenzo 
706436c6c9cSStella Laurenzo PyModule::~PyModule() {
707436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
708436c6c9cSStella Laurenzo   auto &liveModules = getContext()->liveModules;
709436c6c9cSStella Laurenzo   assert(liveModules.count(module.ptr) == 1 &&
710436c6c9cSStella Laurenzo          "destroying module not in live map");
711436c6c9cSStella Laurenzo   liveModules.erase(module.ptr);
712436c6c9cSStella Laurenzo   mlirModuleDestroy(module);
713436c6c9cSStella Laurenzo }
714436c6c9cSStella Laurenzo 
715436c6c9cSStella Laurenzo PyModuleRef PyModule::forModule(MlirModule module) {
716436c6c9cSStella Laurenzo   MlirContext context = mlirModuleGetContext(module);
717436c6c9cSStella Laurenzo   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
718436c6c9cSStella Laurenzo 
719436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
720436c6c9cSStella Laurenzo   auto &liveModules = contextRef->liveModules;
721436c6c9cSStella Laurenzo   auto it = liveModules.find(module.ptr);
722436c6c9cSStella Laurenzo   if (it == liveModules.end()) {
723436c6c9cSStella Laurenzo     // Create.
724436c6c9cSStella Laurenzo     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
725436c6c9cSStella Laurenzo     // Note that the default return value policy on cast is automatic_reference,
726436c6c9cSStella Laurenzo     // which does not take ownership (delete will not be called).
727436c6c9cSStella Laurenzo     // Just be explicit.
728436c6c9cSStella Laurenzo     py::object pyRef =
729436c6c9cSStella Laurenzo         py::cast(unownedModule, py::return_value_policy::take_ownership);
730436c6c9cSStella Laurenzo     unownedModule->handle = pyRef;
731436c6c9cSStella Laurenzo     liveModules[module.ptr] =
732436c6c9cSStella Laurenzo         std::make_pair(unownedModule->handle, unownedModule);
733436c6c9cSStella Laurenzo     return PyModuleRef(unownedModule, std::move(pyRef));
734436c6c9cSStella Laurenzo   }
735436c6c9cSStella Laurenzo   // Use existing.
736436c6c9cSStella Laurenzo   PyModule *existing = it->second.second;
737436c6c9cSStella Laurenzo   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
738436c6c9cSStella Laurenzo   return PyModuleRef(existing, std::move(pyRef));
739436c6c9cSStella Laurenzo }
740436c6c9cSStella Laurenzo 
741436c6c9cSStella Laurenzo py::object PyModule::createFromCapsule(py::object capsule) {
742436c6c9cSStella Laurenzo   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
743436c6c9cSStella Laurenzo   if (mlirModuleIsNull(rawModule))
744436c6c9cSStella Laurenzo     throw py::error_already_set();
745436c6c9cSStella Laurenzo   return forModule(rawModule).releaseObject();
746436c6c9cSStella Laurenzo }
747436c6c9cSStella Laurenzo 
748436c6c9cSStella Laurenzo py::object PyModule::getCapsule() {
749436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
750436c6c9cSStella Laurenzo }
751436c6c9cSStella Laurenzo 
752436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
753436c6c9cSStella Laurenzo // PyOperation
754436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
755436c6c9cSStella Laurenzo 
756436c6c9cSStella Laurenzo PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
757436c6c9cSStella Laurenzo     : BaseContextObject(std::move(contextRef)), operation(operation) {}
758436c6c9cSStella Laurenzo 
759436c6c9cSStella Laurenzo PyOperation::~PyOperation() {
76049745f87SMike Urbach   // If the operation has already been invalidated there is nothing to do.
76149745f87SMike Urbach   if (!valid)
76249745f87SMike Urbach     return;
763436c6c9cSStella Laurenzo   auto &liveOperations = getContext()->liveOperations;
764436c6c9cSStella Laurenzo   assert(liveOperations.count(operation.ptr) == 1 &&
765436c6c9cSStella Laurenzo          "destroying operation not in live map");
766436c6c9cSStella Laurenzo   liveOperations.erase(operation.ptr);
767436c6c9cSStella Laurenzo   if (!isAttached()) {
768436c6c9cSStella Laurenzo     mlirOperationDestroy(operation);
769436c6c9cSStella Laurenzo   }
770436c6c9cSStella Laurenzo }
771436c6c9cSStella Laurenzo 
772436c6c9cSStella Laurenzo PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
773436c6c9cSStella Laurenzo                                            MlirOperation operation,
774436c6c9cSStella Laurenzo                                            py::object parentKeepAlive) {
775436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
776436c6c9cSStella Laurenzo   // Create.
777436c6c9cSStella Laurenzo   PyOperation *unownedOperation =
778436c6c9cSStella Laurenzo       new PyOperation(std::move(contextRef), operation);
779436c6c9cSStella Laurenzo   // Note that the default return value policy on cast is automatic_reference,
780436c6c9cSStella Laurenzo   // which does not take ownership (delete will not be called).
781436c6c9cSStella Laurenzo   // Just be explicit.
782436c6c9cSStella Laurenzo   py::object pyRef =
783436c6c9cSStella Laurenzo       py::cast(unownedOperation, py::return_value_policy::take_ownership);
784436c6c9cSStella Laurenzo   unownedOperation->handle = pyRef;
785436c6c9cSStella Laurenzo   if (parentKeepAlive) {
786436c6c9cSStella Laurenzo     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
787436c6c9cSStella Laurenzo   }
788436c6c9cSStella Laurenzo   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
789436c6c9cSStella Laurenzo   return PyOperationRef(unownedOperation, std::move(pyRef));
790436c6c9cSStella Laurenzo }
791436c6c9cSStella Laurenzo 
792436c6c9cSStella Laurenzo PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
793436c6c9cSStella Laurenzo                                          MlirOperation operation,
794436c6c9cSStella Laurenzo                                          py::object parentKeepAlive) {
795436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
796436c6c9cSStella Laurenzo   auto it = liveOperations.find(operation.ptr);
797436c6c9cSStella Laurenzo   if (it == liveOperations.end()) {
798436c6c9cSStella Laurenzo     // Create.
799436c6c9cSStella Laurenzo     return createInstance(std::move(contextRef), operation,
800436c6c9cSStella Laurenzo                           std::move(parentKeepAlive));
801436c6c9cSStella Laurenzo   }
802436c6c9cSStella Laurenzo   // Use existing.
803436c6c9cSStella Laurenzo   PyOperation *existing = it->second.second;
804436c6c9cSStella Laurenzo   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
805436c6c9cSStella Laurenzo   return PyOperationRef(existing, std::move(pyRef));
806436c6c9cSStella Laurenzo }
807436c6c9cSStella Laurenzo 
808436c6c9cSStella Laurenzo PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
809436c6c9cSStella Laurenzo                                            MlirOperation operation,
810436c6c9cSStella Laurenzo                                            py::object parentKeepAlive) {
811436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
812436c6c9cSStella Laurenzo   assert(liveOperations.count(operation.ptr) == 0 &&
813436c6c9cSStella Laurenzo          "cannot create detached operation that already exists");
814436c6c9cSStella Laurenzo   (void)liveOperations;
815436c6c9cSStella Laurenzo 
816436c6c9cSStella Laurenzo   PyOperationRef created = createInstance(std::move(contextRef), operation,
817436c6c9cSStella Laurenzo                                           std::move(parentKeepAlive));
818436c6c9cSStella Laurenzo   created->attached = false;
819436c6c9cSStella Laurenzo   return created;
820436c6c9cSStella Laurenzo }
821436c6c9cSStella Laurenzo 
822436c6c9cSStella Laurenzo void PyOperation::checkValid() const {
823436c6c9cSStella Laurenzo   if (!valid) {
824436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
825436c6c9cSStella Laurenzo   }
826436c6c9cSStella Laurenzo }
827436c6c9cSStella Laurenzo 
828436c6c9cSStella Laurenzo void PyOperationBase::print(py::object fileObject, bool binary,
829436c6c9cSStella Laurenzo                             llvm::Optional<int64_t> largeElementsLimit,
830436c6c9cSStella Laurenzo                             bool enableDebugInfo, bool prettyDebugInfo,
831436c6c9cSStella Laurenzo                             bool printGenericOpForm, bool useLocalScope) {
832436c6c9cSStella Laurenzo   PyOperation &operation = getOperation();
833436c6c9cSStella Laurenzo   operation.checkValid();
834436c6c9cSStella Laurenzo   if (fileObject.is_none())
835436c6c9cSStella Laurenzo     fileObject = py::module::import("sys").attr("stdout");
836436c6c9cSStella Laurenzo 
837436c6c9cSStella Laurenzo   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
838436c6c9cSStella Laurenzo     fileObject.attr("write")("// Verification failed, printing generic form\n");
839436c6c9cSStella Laurenzo     printGenericOpForm = true;
840436c6c9cSStella Laurenzo   }
841436c6c9cSStella Laurenzo 
842436c6c9cSStella Laurenzo   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
843436c6c9cSStella Laurenzo   if (largeElementsLimit)
844436c6c9cSStella Laurenzo     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
845436c6c9cSStella Laurenzo   if (enableDebugInfo)
846436c6c9cSStella Laurenzo     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
847436c6c9cSStella Laurenzo   if (printGenericOpForm)
848436c6c9cSStella Laurenzo     mlirOpPrintingFlagsPrintGenericOpForm(flags);
849436c6c9cSStella Laurenzo 
850436c6c9cSStella Laurenzo   PyFileAccumulator accum(fileObject, binary);
851436c6c9cSStella Laurenzo   py::gil_scoped_release();
852436c6c9cSStella Laurenzo   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
853436c6c9cSStella Laurenzo                               accum.getUserData());
854436c6c9cSStella Laurenzo   mlirOpPrintingFlagsDestroy(flags);
855436c6c9cSStella Laurenzo }
856436c6c9cSStella Laurenzo 
857436c6c9cSStella Laurenzo py::object PyOperationBase::getAsm(bool binary,
858436c6c9cSStella Laurenzo                                    llvm::Optional<int64_t> largeElementsLimit,
859436c6c9cSStella Laurenzo                                    bool enableDebugInfo, bool prettyDebugInfo,
860436c6c9cSStella Laurenzo                                    bool printGenericOpForm,
861436c6c9cSStella Laurenzo                                    bool useLocalScope) {
862436c6c9cSStella Laurenzo   py::object fileObject;
863436c6c9cSStella Laurenzo   if (binary) {
864436c6c9cSStella Laurenzo     fileObject = py::module::import("io").attr("BytesIO")();
865436c6c9cSStella Laurenzo   } else {
866436c6c9cSStella Laurenzo     fileObject = py::module::import("io").attr("StringIO")();
867436c6c9cSStella Laurenzo   }
868436c6c9cSStella Laurenzo   print(fileObject, /*binary=*/binary,
869436c6c9cSStella Laurenzo         /*largeElementsLimit=*/largeElementsLimit,
870436c6c9cSStella Laurenzo         /*enableDebugInfo=*/enableDebugInfo,
871436c6c9cSStella Laurenzo         /*prettyDebugInfo=*/prettyDebugInfo,
872436c6c9cSStella Laurenzo         /*printGenericOpForm=*/printGenericOpForm,
873436c6c9cSStella Laurenzo         /*useLocalScope=*/useLocalScope);
874436c6c9cSStella Laurenzo 
875436c6c9cSStella Laurenzo   return fileObject.attr("getvalue")();
876436c6c9cSStella Laurenzo }
877436c6c9cSStella Laurenzo 
8781689dadeSJohn Demme llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
87949745f87SMike Urbach   checkValid();
880436c6c9cSStella Laurenzo   if (!isAttached())
881436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
882436c6c9cSStella Laurenzo   MlirOperation operation = mlirOperationGetParentOperation(get());
883436c6c9cSStella Laurenzo   if (mlirOperationIsNull(operation))
8841689dadeSJohn Demme     return {};
885436c6c9cSStella Laurenzo   return PyOperation::forOperation(getContext(), operation);
886436c6c9cSStella Laurenzo }
887436c6c9cSStella Laurenzo 
888436c6c9cSStella Laurenzo PyBlock PyOperation::getBlock() {
88949745f87SMike Urbach   checkValid();
8901689dadeSJohn Demme   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
891436c6c9cSStella Laurenzo   MlirBlock block = mlirOperationGetBlock(get());
892436c6c9cSStella Laurenzo   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
8931689dadeSJohn Demme   assert(parentOperation && "Operation has no parent");
8941689dadeSJohn Demme   return PyBlock{std::move(*parentOperation), block};
895436c6c9cSStella Laurenzo }
896436c6c9cSStella Laurenzo 
8970126e906SJohn Demme py::object PyOperation::getCapsule() {
89849745f87SMike Urbach   checkValid();
8990126e906SJohn Demme   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
9000126e906SJohn Demme }
9010126e906SJohn Demme 
9020126e906SJohn Demme py::object PyOperation::createFromCapsule(py::object capsule) {
9030126e906SJohn Demme   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
9040126e906SJohn Demme   if (mlirOperationIsNull(rawOperation))
9050126e906SJohn Demme     throw py::error_already_set();
9060126e906SJohn Demme   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
9070126e906SJohn Demme   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
9080126e906SJohn Demme       .releaseObject();
9090126e906SJohn Demme }
9100126e906SJohn Demme 
911436c6c9cSStella Laurenzo py::object PyOperation::create(
912436c6c9cSStella Laurenzo     std::string name, llvm::Optional<std::vector<PyType *>> results,
913436c6c9cSStella Laurenzo     llvm::Optional<std::vector<PyValue *>> operands,
914436c6c9cSStella Laurenzo     llvm::Optional<py::dict> attributes,
915436c6c9cSStella Laurenzo     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
916436c6c9cSStella Laurenzo     DefaultingPyLocation location, py::object maybeIp) {
917436c6c9cSStella Laurenzo   llvm::SmallVector<MlirValue, 4> mlirOperands;
918436c6c9cSStella Laurenzo   llvm::SmallVector<MlirType, 4> mlirResults;
919436c6c9cSStella Laurenzo   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
920436c6c9cSStella Laurenzo   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
921436c6c9cSStella Laurenzo 
922436c6c9cSStella Laurenzo   // General parameter validation.
923436c6c9cSStella Laurenzo   if (regions < 0)
924436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
925436c6c9cSStella Laurenzo 
926436c6c9cSStella Laurenzo   // Unpack/validate operands.
927436c6c9cSStella Laurenzo   if (operands) {
928436c6c9cSStella Laurenzo     mlirOperands.reserve(operands->size());
929436c6c9cSStella Laurenzo     for (PyValue *operand : *operands) {
930436c6c9cSStella Laurenzo       if (!operand)
931436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
932436c6c9cSStella Laurenzo       mlirOperands.push_back(operand->get());
933436c6c9cSStella Laurenzo     }
934436c6c9cSStella Laurenzo   }
935436c6c9cSStella Laurenzo 
936436c6c9cSStella Laurenzo   // Unpack/validate results.
937436c6c9cSStella Laurenzo   if (results) {
938436c6c9cSStella Laurenzo     mlirResults.reserve(results->size());
939436c6c9cSStella Laurenzo     for (PyType *result : *results) {
940436c6c9cSStella Laurenzo       // TODO: Verify result type originate from the same context.
941436c6c9cSStella Laurenzo       if (!result)
942436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "result type cannot be None");
943436c6c9cSStella Laurenzo       mlirResults.push_back(*result);
944436c6c9cSStella Laurenzo     }
945436c6c9cSStella Laurenzo   }
946436c6c9cSStella Laurenzo   // Unpack/validate attributes.
947436c6c9cSStella Laurenzo   if (attributes) {
948436c6c9cSStella Laurenzo     mlirAttributes.reserve(attributes->size());
949436c6c9cSStella Laurenzo     for (auto &it : *attributes) {
950436c6c9cSStella Laurenzo       std::string key;
951436c6c9cSStella Laurenzo       try {
952436c6c9cSStella Laurenzo         key = it.first.cast<std::string>();
953436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
954436c6c9cSStella Laurenzo         std::string msg = "Invalid attribute key (not a string) when "
955436c6c9cSStella Laurenzo                           "attempting to create the operation \"" +
956436c6c9cSStella Laurenzo                           name + "\" (" + err.what() + ")";
957436c6c9cSStella Laurenzo         throw py::cast_error(msg);
958436c6c9cSStella Laurenzo       }
959436c6c9cSStella Laurenzo       try {
960436c6c9cSStella Laurenzo         auto &attribute = it.second.cast<PyAttribute &>();
961436c6c9cSStella Laurenzo         // TODO: Verify attribute originates from the same context.
962436c6c9cSStella Laurenzo         mlirAttributes.emplace_back(std::move(key), attribute);
963436c6c9cSStella Laurenzo       } catch (py::reference_cast_error &) {
964436c6c9cSStella Laurenzo         // This exception seems thrown when the value is "None".
965436c6c9cSStella Laurenzo         std::string msg =
966436c6c9cSStella Laurenzo             "Found an invalid (`None`?) attribute value for the key \"" + key +
967436c6c9cSStella Laurenzo             "\" when attempting to create the operation \"" + name + "\"";
968436c6c9cSStella Laurenzo         throw py::cast_error(msg);
969436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
970436c6c9cSStella Laurenzo         std::string msg = "Invalid attribute value for the key \"" + key +
971436c6c9cSStella Laurenzo                           "\" when attempting to create the operation \"" +
972436c6c9cSStella Laurenzo                           name + "\" (" + err.what() + ")";
973436c6c9cSStella Laurenzo         throw py::cast_error(msg);
974436c6c9cSStella Laurenzo       }
975436c6c9cSStella Laurenzo     }
976436c6c9cSStella Laurenzo   }
977436c6c9cSStella Laurenzo   // Unpack/validate successors.
978436c6c9cSStella Laurenzo   if (successors) {
979436c6c9cSStella Laurenzo     mlirSuccessors.reserve(successors->size());
980436c6c9cSStella Laurenzo     for (auto *successor : *successors) {
981436c6c9cSStella Laurenzo       // TODO: Verify successor originate from the same context.
982436c6c9cSStella Laurenzo       if (!successor)
983436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
984436c6c9cSStella Laurenzo       mlirSuccessors.push_back(successor->get());
985436c6c9cSStella Laurenzo     }
986436c6c9cSStella Laurenzo   }
987436c6c9cSStella Laurenzo 
988436c6c9cSStella Laurenzo   // Apply unpacked/validated to the operation state. Beyond this
989436c6c9cSStella Laurenzo   // point, exceptions cannot be thrown or else the state will leak.
990436c6c9cSStella Laurenzo   MlirOperationState state =
991436c6c9cSStella Laurenzo       mlirOperationStateGet(toMlirStringRef(name), location);
992436c6c9cSStella Laurenzo   if (!mlirOperands.empty())
993436c6c9cSStella Laurenzo     mlirOperationStateAddOperands(&state, mlirOperands.size(),
994436c6c9cSStella Laurenzo                                   mlirOperands.data());
995436c6c9cSStella Laurenzo   if (!mlirResults.empty())
996436c6c9cSStella Laurenzo     mlirOperationStateAddResults(&state, mlirResults.size(),
997436c6c9cSStella Laurenzo                                  mlirResults.data());
998436c6c9cSStella Laurenzo   if (!mlirAttributes.empty()) {
999436c6c9cSStella Laurenzo     // Note that the attribute names directly reference bytes in
1000436c6c9cSStella Laurenzo     // mlirAttributes, so that vector must not be changed from here
1001436c6c9cSStella Laurenzo     // on.
1002436c6c9cSStella Laurenzo     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1003436c6c9cSStella Laurenzo     mlirNamedAttributes.reserve(mlirAttributes.size());
1004436c6c9cSStella Laurenzo     for (auto &it : mlirAttributes)
1005436c6c9cSStella Laurenzo       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1006436c6c9cSStella Laurenzo           mlirIdentifierGet(mlirAttributeGetContext(it.second),
1007436c6c9cSStella Laurenzo                             toMlirStringRef(it.first)),
1008436c6c9cSStella Laurenzo           it.second));
1009436c6c9cSStella Laurenzo     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1010436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1011436c6c9cSStella Laurenzo   }
1012436c6c9cSStella Laurenzo   if (!mlirSuccessors.empty())
1013436c6c9cSStella Laurenzo     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1014436c6c9cSStella Laurenzo                                     mlirSuccessors.data());
1015436c6c9cSStella Laurenzo   if (regions) {
1016436c6c9cSStella Laurenzo     llvm::SmallVector<MlirRegion, 4> mlirRegions;
1017436c6c9cSStella Laurenzo     mlirRegions.resize(regions);
1018436c6c9cSStella Laurenzo     for (int i = 0; i < regions; ++i)
1019436c6c9cSStella Laurenzo       mlirRegions[i] = mlirRegionCreate();
1020436c6c9cSStella Laurenzo     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1021436c6c9cSStella Laurenzo                                       mlirRegions.data());
1022436c6c9cSStella Laurenzo   }
1023436c6c9cSStella Laurenzo 
1024436c6c9cSStella Laurenzo   // Construct the operation.
1025436c6c9cSStella Laurenzo   MlirOperation operation = mlirOperationCreate(&state);
1026436c6c9cSStella Laurenzo   PyOperationRef created =
1027436c6c9cSStella Laurenzo       PyOperation::createDetached(location->getContext(), operation);
1028436c6c9cSStella Laurenzo 
1029436c6c9cSStella Laurenzo   // InsertPoint active?
1030436c6c9cSStella Laurenzo   if (!maybeIp.is(py::cast(false))) {
1031436c6c9cSStella Laurenzo     PyInsertionPoint *ip;
1032436c6c9cSStella Laurenzo     if (maybeIp.is_none()) {
1033436c6c9cSStella Laurenzo       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1034436c6c9cSStella Laurenzo     } else {
1035436c6c9cSStella Laurenzo       ip = py::cast<PyInsertionPoint *>(maybeIp);
1036436c6c9cSStella Laurenzo     }
1037436c6c9cSStella Laurenzo     if (ip)
1038436c6c9cSStella Laurenzo       ip->insert(*created.get());
1039436c6c9cSStella Laurenzo   }
1040436c6c9cSStella Laurenzo 
1041436c6c9cSStella Laurenzo   return created->createOpView();
1042436c6c9cSStella Laurenzo }
1043436c6c9cSStella Laurenzo 
1044436c6c9cSStella Laurenzo py::object PyOperation::createOpView() {
104549745f87SMike Urbach   checkValid();
1046436c6c9cSStella Laurenzo   MlirIdentifier ident = mlirOperationGetName(get());
1047436c6c9cSStella Laurenzo   MlirStringRef identStr = mlirIdentifierStr(ident);
1048436c6c9cSStella Laurenzo   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1049436c6c9cSStella Laurenzo       StringRef(identStr.data, identStr.length));
1050436c6c9cSStella Laurenzo   if (opViewClass)
1051436c6c9cSStella Laurenzo     return (*opViewClass)(getRef().getObject());
1052436c6c9cSStella Laurenzo   return py::cast(PyOpView(getRef().getObject()));
1053436c6c9cSStella Laurenzo }
1054436c6c9cSStella Laurenzo 
105549745f87SMike Urbach void PyOperation::erase() {
105649745f87SMike Urbach   checkValid();
105749745f87SMike Urbach   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
105849745f87SMike Urbach   // Python reference to a child operation is live. All children should also
105949745f87SMike Urbach   // have their `valid` bit set to false.
106049745f87SMike Urbach   auto &liveOperations = getContext()->liveOperations;
106149745f87SMike Urbach   if (liveOperations.count(operation.ptr))
106249745f87SMike Urbach     liveOperations.erase(operation.ptr);
106349745f87SMike Urbach   mlirOperationDestroy(operation);
106449745f87SMike Urbach   valid = false;
106549745f87SMike Urbach }
106649745f87SMike Urbach 
1067436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1068436c6c9cSStella Laurenzo // PyOpView
1069436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1070436c6c9cSStella Laurenzo 
1071436c6c9cSStella Laurenzo py::object
1072436c6c9cSStella Laurenzo PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1073436c6c9cSStella Laurenzo                        py::list operandList,
1074436c6c9cSStella Laurenzo                        llvm::Optional<py::dict> attributes,
1075436c6c9cSStella Laurenzo                        llvm::Optional<std::vector<PyBlock *>> successors,
1076436c6c9cSStella Laurenzo                        llvm::Optional<int> regions,
1077436c6c9cSStella Laurenzo                        DefaultingPyLocation location, py::object maybeIp) {
1078436c6c9cSStella Laurenzo   PyMlirContextRef context = location->getContext();
1079436c6c9cSStella Laurenzo   // Class level operation construction metadata.
1080436c6c9cSStella Laurenzo   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1081436c6c9cSStella Laurenzo   // Operand and result segment specs are either none, which does no
1082436c6c9cSStella Laurenzo   // variadic unpacking, or a list of ints with segment sizes, where each
1083436c6c9cSStella Laurenzo   // element is either a positive number (typically 1 for a scalar) or -1 to
1084436c6c9cSStella Laurenzo   // indicate that it is derived from the length of the same-indexed operand
1085436c6c9cSStella Laurenzo   // or result (implying that it is a list at that position).
1086436c6c9cSStella Laurenzo   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1087436c6c9cSStella Laurenzo   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1088436c6c9cSStella Laurenzo 
10898d05a288SStella Laurenzo   std::vector<uint32_t> operandSegmentLengths;
10908d05a288SStella Laurenzo   std::vector<uint32_t> resultSegmentLengths;
1091436c6c9cSStella Laurenzo 
1092436c6c9cSStella Laurenzo   // Validate/determine region count.
1093436c6c9cSStella Laurenzo   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1094436c6c9cSStella Laurenzo   int opMinRegionCount = std::get<0>(opRegionSpec);
1095436c6c9cSStella Laurenzo   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1096436c6c9cSStella Laurenzo   if (!regions) {
1097436c6c9cSStella Laurenzo     regions = opMinRegionCount;
1098436c6c9cSStella Laurenzo   }
1099436c6c9cSStella Laurenzo   if (*regions < opMinRegionCount) {
1100436c6c9cSStella Laurenzo     throw py::value_error(
1101436c6c9cSStella Laurenzo         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1102436c6c9cSStella Laurenzo          llvm::Twine(opMinRegionCount) +
1103436c6c9cSStella Laurenzo          " regions but was built with regions=" + llvm::Twine(*regions))
1104436c6c9cSStella Laurenzo             .str());
1105436c6c9cSStella Laurenzo   }
1106436c6c9cSStella Laurenzo   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1107436c6c9cSStella Laurenzo     throw py::value_error(
1108436c6c9cSStella Laurenzo         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1109436c6c9cSStella Laurenzo          llvm::Twine(opMinRegionCount) +
1110436c6c9cSStella Laurenzo          " regions but was built with regions=" + llvm::Twine(*regions))
1111436c6c9cSStella Laurenzo             .str());
1112436c6c9cSStella Laurenzo   }
1113436c6c9cSStella Laurenzo 
1114436c6c9cSStella Laurenzo   // Unpack results.
1115436c6c9cSStella Laurenzo   std::vector<PyType *> resultTypes;
1116436c6c9cSStella Laurenzo   resultTypes.reserve(resultTypeList.size());
1117436c6c9cSStella Laurenzo   if (resultSegmentSpecObj.is_none()) {
1118436c6c9cSStella Laurenzo     // Non-variadic result unpacking.
1119436c6c9cSStella Laurenzo     for (auto it : llvm::enumerate(resultTypeList)) {
1120436c6c9cSStella Laurenzo       try {
1121436c6c9cSStella Laurenzo         resultTypes.push_back(py::cast<PyType *>(it.value()));
1122436c6c9cSStella Laurenzo         if (!resultTypes.back())
1123436c6c9cSStella Laurenzo           throw py::cast_error();
1124436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
1125436c6c9cSStella Laurenzo         throw py::value_error((llvm::Twine("Result ") +
1126436c6c9cSStella Laurenzo                                llvm::Twine(it.index()) + " of operation \"" +
1127436c6c9cSStella Laurenzo                                name + "\" must be a Type (" + err.what() + ")")
1128436c6c9cSStella Laurenzo                                   .str());
1129436c6c9cSStella Laurenzo       }
1130436c6c9cSStella Laurenzo     }
1131436c6c9cSStella Laurenzo   } else {
1132436c6c9cSStella Laurenzo     // Sized result unpacking.
1133436c6c9cSStella Laurenzo     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1134436c6c9cSStella Laurenzo     if (resultSegmentSpec.size() != resultTypeList.size()) {
1135436c6c9cSStella Laurenzo       throw py::value_error((llvm::Twine("Operation \"") + name +
1136436c6c9cSStella Laurenzo                              "\" requires " +
1137436c6c9cSStella Laurenzo                              llvm::Twine(resultSegmentSpec.size()) +
1138436c6c9cSStella Laurenzo                              "result segments but was provided " +
1139436c6c9cSStella Laurenzo                              llvm::Twine(resultTypeList.size()))
1140436c6c9cSStella Laurenzo                                 .str());
1141436c6c9cSStella Laurenzo     }
1142436c6c9cSStella Laurenzo     resultSegmentLengths.reserve(resultTypeList.size());
1143436c6c9cSStella Laurenzo     for (auto it :
1144436c6c9cSStella Laurenzo          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1145436c6c9cSStella Laurenzo       int segmentSpec = std::get<1>(it.value());
1146436c6c9cSStella Laurenzo       if (segmentSpec == 1 || segmentSpec == 0) {
1147436c6c9cSStella Laurenzo         // Unpack unary element.
1148436c6c9cSStella Laurenzo         try {
1149436c6c9cSStella Laurenzo           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1150436c6c9cSStella Laurenzo           if (resultType) {
1151436c6c9cSStella Laurenzo             resultTypes.push_back(resultType);
1152436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(1);
1153436c6c9cSStella Laurenzo           } else if (segmentSpec == 0) {
1154436c6c9cSStella Laurenzo             // Allowed to be optional.
1155436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(0);
1156436c6c9cSStella Laurenzo           } else {
1157436c6c9cSStella Laurenzo             throw py::cast_error("was None and result is not optional");
1158436c6c9cSStella Laurenzo           }
1159436c6c9cSStella Laurenzo         } catch (py::cast_error &err) {
1160436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Result ") +
1161436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1162436c6c9cSStella Laurenzo                                  name + "\" must be a Type (" + err.what() +
1163436c6c9cSStella Laurenzo                                  ")")
1164436c6c9cSStella Laurenzo                                     .str());
1165436c6c9cSStella Laurenzo         }
1166436c6c9cSStella Laurenzo       } else if (segmentSpec == -1) {
1167436c6c9cSStella Laurenzo         // Unpack sequence by appending.
1168436c6c9cSStella Laurenzo         try {
1169436c6c9cSStella Laurenzo           if (std::get<0>(it.value()).is_none()) {
1170436c6c9cSStella Laurenzo             // Treat it as an empty list.
1171436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(0);
1172436c6c9cSStella Laurenzo           } else {
1173436c6c9cSStella Laurenzo             // Unpack the list.
1174436c6c9cSStella Laurenzo             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1175436c6c9cSStella Laurenzo             for (py::object segmentItem : segment) {
1176436c6c9cSStella Laurenzo               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1177436c6c9cSStella Laurenzo               if (!resultTypes.back()) {
1178436c6c9cSStella Laurenzo                 throw py::cast_error("contained a None item");
1179436c6c9cSStella Laurenzo               }
1180436c6c9cSStella Laurenzo             }
1181436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(segment.size());
1182436c6c9cSStella Laurenzo           }
1183436c6c9cSStella Laurenzo         } catch (std::exception &err) {
1184436c6c9cSStella Laurenzo           // NOTE: Sloppy to be using a catch-all here, but there are at least
1185436c6c9cSStella Laurenzo           // three different unrelated exceptions that can be thrown in the
1186436c6c9cSStella Laurenzo           // above "casts". Just keep the scope above small and catch them all.
1187436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Result ") +
1188436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1189436c6c9cSStella Laurenzo                                  name + "\" must be a Sequence of Types (" +
1190436c6c9cSStella Laurenzo                                  err.what() + ")")
1191436c6c9cSStella Laurenzo                                     .str());
1192436c6c9cSStella Laurenzo         }
1193436c6c9cSStella Laurenzo       } else {
1194436c6c9cSStella Laurenzo         throw py::value_error("Unexpected segment spec");
1195436c6c9cSStella Laurenzo       }
1196436c6c9cSStella Laurenzo     }
1197436c6c9cSStella Laurenzo   }
1198436c6c9cSStella Laurenzo 
1199436c6c9cSStella Laurenzo   // Unpack operands.
1200436c6c9cSStella Laurenzo   std::vector<PyValue *> operands;
1201436c6c9cSStella Laurenzo   operands.reserve(operands.size());
1202436c6c9cSStella Laurenzo   if (operandSegmentSpecObj.is_none()) {
1203436c6c9cSStella Laurenzo     // Non-sized operand unpacking.
1204436c6c9cSStella Laurenzo     for (auto it : llvm::enumerate(operandList)) {
1205436c6c9cSStella Laurenzo       try {
1206436c6c9cSStella Laurenzo         operands.push_back(py::cast<PyValue *>(it.value()));
1207436c6c9cSStella Laurenzo         if (!operands.back())
1208436c6c9cSStella Laurenzo           throw py::cast_error();
1209436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
1210436c6c9cSStella Laurenzo         throw py::value_error((llvm::Twine("Operand ") +
1211436c6c9cSStella Laurenzo                                llvm::Twine(it.index()) + " of operation \"" +
1212436c6c9cSStella Laurenzo                                name + "\" must be a Value (" + err.what() + ")")
1213436c6c9cSStella Laurenzo                                   .str());
1214436c6c9cSStella Laurenzo       }
1215436c6c9cSStella Laurenzo     }
1216436c6c9cSStella Laurenzo   } else {
1217436c6c9cSStella Laurenzo     // Sized operand unpacking.
1218436c6c9cSStella Laurenzo     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1219436c6c9cSStella Laurenzo     if (operandSegmentSpec.size() != operandList.size()) {
1220436c6c9cSStella Laurenzo       throw py::value_error((llvm::Twine("Operation \"") + name +
1221436c6c9cSStella Laurenzo                              "\" requires " +
1222436c6c9cSStella Laurenzo                              llvm::Twine(operandSegmentSpec.size()) +
1223436c6c9cSStella Laurenzo                              "operand segments but was provided " +
1224436c6c9cSStella Laurenzo                              llvm::Twine(operandList.size()))
1225436c6c9cSStella Laurenzo                                 .str());
1226436c6c9cSStella Laurenzo     }
1227436c6c9cSStella Laurenzo     operandSegmentLengths.reserve(operandList.size());
1228436c6c9cSStella Laurenzo     for (auto it :
1229436c6c9cSStella Laurenzo          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1230436c6c9cSStella Laurenzo       int segmentSpec = std::get<1>(it.value());
1231436c6c9cSStella Laurenzo       if (segmentSpec == 1 || segmentSpec == 0) {
1232436c6c9cSStella Laurenzo         // Unpack unary element.
1233436c6c9cSStella Laurenzo         try {
1234436c6c9cSStella Laurenzo           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1235436c6c9cSStella Laurenzo           if (operandValue) {
1236436c6c9cSStella Laurenzo             operands.push_back(operandValue);
1237436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(1);
1238436c6c9cSStella Laurenzo           } else if (segmentSpec == 0) {
1239436c6c9cSStella Laurenzo             // Allowed to be optional.
1240436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(0);
1241436c6c9cSStella Laurenzo           } else {
1242436c6c9cSStella Laurenzo             throw py::cast_error("was None and operand is not optional");
1243436c6c9cSStella Laurenzo           }
1244436c6c9cSStella Laurenzo         } catch (py::cast_error &err) {
1245436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Operand ") +
1246436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1247436c6c9cSStella Laurenzo                                  name + "\" must be a Value (" + err.what() +
1248436c6c9cSStella Laurenzo                                  ")")
1249436c6c9cSStella Laurenzo                                     .str());
1250436c6c9cSStella Laurenzo         }
1251436c6c9cSStella Laurenzo       } else if (segmentSpec == -1) {
1252436c6c9cSStella Laurenzo         // Unpack sequence by appending.
1253436c6c9cSStella Laurenzo         try {
1254436c6c9cSStella Laurenzo           if (std::get<0>(it.value()).is_none()) {
1255436c6c9cSStella Laurenzo             // Treat it as an empty list.
1256436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(0);
1257436c6c9cSStella Laurenzo           } else {
1258436c6c9cSStella Laurenzo             // Unpack the list.
1259436c6c9cSStella Laurenzo             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1260436c6c9cSStella Laurenzo             for (py::object segmentItem : segment) {
1261436c6c9cSStella Laurenzo               operands.push_back(py::cast<PyValue *>(segmentItem));
1262436c6c9cSStella Laurenzo               if (!operands.back()) {
1263436c6c9cSStella Laurenzo                 throw py::cast_error("contained a None item");
1264436c6c9cSStella Laurenzo               }
1265436c6c9cSStella Laurenzo             }
1266436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(segment.size());
1267436c6c9cSStella Laurenzo           }
1268436c6c9cSStella Laurenzo         } catch (std::exception &err) {
1269436c6c9cSStella Laurenzo           // NOTE: Sloppy to be using a catch-all here, but there are at least
1270436c6c9cSStella Laurenzo           // three different unrelated exceptions that can be thrown in the
1271436c6c9cSStella Laurenzo           // above "casts". Just keep the scope above small and catch them all.
1272436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Operand ") +
1273436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1274436c6c9cSStella Laurenzo                                  name + "\" must be a Sequence of Values (" +
1275436c6c9cSStella Laurenzo                                  err.what() + ")")
1276436c6c9cSStella Laurenzo                                     .str());
1277436c6c9cSStella Laurenzo         }
1278436c6c9cSStella Laurenzo       } else {
1279436c6c9cSStella Laurenzo         throw py::value_error("Unexpected segment spec");
1280436c6c9cSStella Laurenzo       }
1281436c6c9cSStella Laurenzo     }
1282436c6c9cSStella Laurenzo   }
1283436c6c9cSStella Laurenzo 
1284436c6c9cSStella Laurenzo   // Merge operand/result segment lengths into attributes if needed.
1285436c6c9cSStella Laurenzo   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1286436c6c9cSStella Laurenzo     // Dup.
1287436c6c9cSStella Laurenzo     if (attributes) {
1288436c6c9cSStella Laurenzo       attributes = py::dict(*attributes);
1289436c6c9cSStella Laurenzo     } else {
1290436c6c9cSStella Laurenzo       attributes = py::dict();
1291436c6c9cSStella Laurenzo     }
1292436c6c9cSStella Laurenzo     if (attributes->contains("result_segment_sizes") ||
1293436c6c9cSStella Laurenzo         attributes->contains("operand_segment_sizes")) {
1294436c6c9cSStella Laurenzo       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1295436c6c9cSStella Laurenzo                             "'operand_segment_sizes' attribute is unsupported. "
1296436c6c9cSStella Laurenzo                             "Use Operation.create for such low-level access.");
1297436c6c9cSStella Laurenzo     }
1298436c6c9cSStella Laurenzo 
1299436c6c9cSStella Laurenzo     // Add result_segment_sizes attribute.
1300436c6c9cSStella Laurenzo     if (!resultSegmentLengths.empty()) {
1301436c6c9cSStella Laurenzo       int64_t size = resultSegmentLengths.size();
13028d05a288SStella Laurenzo       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
13038d05a288SStella Laurenzo           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1304436c6c9cSStella Laurenzo           resultSegmentLengths.size(), resultSegmentLengths.data());
1305436c6c9cSStella Laurenzo       (*attributes)["result_segment_sizes"] =
1306436c6c9cSStella Laurenzo           PyAttribute(context, segmentLengthAttr);
1307436c6c9cSStella Laurenzo     }
1308436c6c9cSStella Laurenzo 
1309436c6c9cSStella Laurenzo     // Add operand_segment_sizes attribute.
1310436c6c9cSStella Laurenzo     if (!operandSegmentLengths.empty()) {
1311436c6c9cSStella Laurenzo       int64_t size = operandSegmentLengths.size();
13128d05a288SStella Laurenzo       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
13138d05a288SStella Laurenzo           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1314436c6c9cSStella Laurenzo           operandSegmentLengths.size(), operandSegmentLengths.data());
1315436c6c9cSStella Laurenzo       (*attributes)["operand_segment_sizes"] =
1316436c6c9cSStella Laurenzo           PyAttribute(context, segmentLengthAttr);
1317436c6c9cSStella Laurenzo     }
1318436c6c9cSStella Laurenzo   }
1319436c6c9cSStella Laurenzo 
1320436c6c9cSStella Laurenzo   // Delegate to create.
1321436c6c9cSStella Laurenzo   return PyOperation::create(std::move(name),
1322436c6c9cSStella Laurenzo                              /*results=*/std::move(resultTypes),
1323436c6c9cSStella Laurenzo                              /*operands=*/std::move(operands),
1324436c6c9cSStella Laurenzo                              /*attributes=*/std::move(attributes),
1325436c6c9cSStella Laurenzo                              /*successors=*/std::move(successors),
1326436c6c9cSStella Laurenzo                              /*regions=*/*regions, location, maybeIp);
1327436c6c9cSStella Laurenzo }
1328436c6c9cSStella Laurenzo 
1329436c6c9cSStella Laurenzo PyOpView::PyOpView(py::object operationObject)
1330436c6c9cSStella Laurenzo     // Casting through the PyOperationBase base-class and then back to the
1331436c6c9cSStella Laurenzo     // Operation lets us accept any PyOperationBase subclass.
1332436c6c9cSStella Laurenzo     : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1333436c6c9cSStella Laurenzo       operationObject(operation.getRef().getObject()) {}
1334436c6c9cSStella Laurenzo 
1335436c6c9cSStella Laurenzo py::object PyOpView::createRawSubclass(py::object userClass) {
1336436c6c9cSStella Laurenzo   // This is... a little gross. The typical pattern is to have a pure python
1337436c6c9cSStella Laurenzo   // class that extends OpView like:
1338436c6c9cSStella Laurenzo   //   class AddFOp(_cext.ir.OpView):
1339436c6c9cSStella Laurenzo   //     def __init__(self, loc, lhs, rhs):
1340436c6c9cSStella Laurenzo   //       operation = loc.context.create_operation(
1341436c6c9cSStella Laurenzo   //           "addf", lhs, rhs, results=[lhs.type])
1342436c6c9cSStella Laurenzo   //       super().__init__(operation)
1343436c6c9cSStella Laurenzo   //
1344436c6c9cSStella Laurenzo   // I.e. The goal of the user facing type is to provide a nice constructor
1345436c6c9cSStella Laurenzo   // that has complete freedom for the op under construction. This is at odds
1346436c6c9cSStella Laurenzo   // with our other desire to sometimes create this object by just passing an
1347436c6c9cSStella Laurenzo   // operation (to initialize the base class). We could do *arg and **kwargs
1348436c6c9cSStella Laurenzo   // munging to try to make it work, but instead, we synthesize a new class
1349436c6c9cSStella Laurenzo   // on the fly which extends this user class (AddFOp in this example) and
1350436c6c9cSStella Laurenzo   // *give it* the base class's __init__ method, thus bypassing the
1351436c6c9cSStella Laurenzo   // intermediate subclass's __init__ method entirely. While slightly,
1352436c6c9cSStella Laurenzo   // underhanded, this is safe/legal because the type hierarchy has not changed
1353436c6c9cSStella Laurenzo   // (we just added a new leaf) and we aren't mucking around with __new__.
1354436c6c9cSStella Laurenzo   // Typically, this new class will be stored on the original as "_Raw" and will
1355436c6c9cSStella Laurenzo   // be used for casts and other things that need a variant of the class that
1356436c6c9cSStella Laurenzo   // is initialized purely from an operation.
1357436c6c9cSStella Laurenzo   py::object parentMetaclass =
1358436c6c9cSStella Laurenzo       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1359436c6c9cSStella Laurenzo   py::dict attributes;
1360436c6c9cSStella Laurenzo   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1361436c6c9cSStella Laurenzo   // now.
1362436c6c9cSStella Laurenzo   //   auto opViewType = py::type::of<PyOpView>();
1363436c6c9cSStella Laurenzo   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1364436c6c9cSStella Laurenzo   attributes["__init__"] = opViewType.attr("__init__");
1365436c6c9cSStella Laurenzo   py::str origName = userClass.attr("__name__");
1366436c6c9cSStella Laurenzo   py::str newName = py::str("_") + origName;
1367436c6c9cSStella Laurenzo   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1368436c6c9cSStella Laurenzo }
1369436c6c9cSStella Laurenzo 
1370436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1371436c6c9cSStella Laurenzo // PyInsertionPoint.
1372436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1373436c6c9cSStella Laurenzo 
1374436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1375436c6c9cSStella Laurenzo 
1376436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1377436c6c9cSStella Laurenzo     : refOperation(beforeOperationBase.getOperation().getRef()),
1378436c6c9cSStella Laurenzo       block((*refOperation)->getBlock()) {}
1379436c6c9cSStella Laurenzo 
1380436c6c9cSStella Laurenzo void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1381436c6c9cSStella Laurenzo   PyOperation &operation = operationBase.getOperation();
1382436c6c9cSStella Laurenzo   if (operation.isAttached())
1383436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError,
1384436c6c9cSStella Laurenzo                      "Attempt to insert operation that is already attached");
1385436c6c9cSStella Laurenzo   block.getParentOperation()->checkValid();
1386436c6c9cSStella Laurenzo   MlirOperation beforeOp = {nullptr};
1387436c6c9cSStella Laurenzo   if (refOperation) {
1388436c6c9cSStella Laurenzo     // Insert before operation.
1389436c6c9cSStella Laurenzo     (*refOperation)->checkValid();
1390436c6c9cSStella Laurenzo     beforeOp = (*refOperation)->get();
1391436c6c9cSStella Laurenzo   } else {
1392436c6c9cSStella Laurenzo     // Insert at end (before null) is only valid if the block does not
1393436c6c9cSStella Laurenzo     // already end in a known terminator (violating this will cause assertion
1394436c6c9cSStella Laurenzo     // failures later).
1395436c6c9cSStella Laurenzo     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1396436c6c9cSStella Laurenzo       throw py::index_error("Cannot insert operation at the end of a block "
1397436c6c9cSStella Laurenzo                             "that already has a terminator. Did you mean to "
1398436c6c9cSStella Laurenzo                             "use 'InsertionPoint.at_block_terminator(block)' "
1399436c6c9cSStella Laurenzo                             "versus 'InsertionPoint(block)'?");
1400436c6c9cSStella Laurenzo     }
1401436c6c9cSStella Laurenzo   }
1402436c6c9cSStella Laurenzo   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1403436c6c9cSStella Laurenzo   operation.setAttached();
1404436c6c9cSStella Laurenzo }
1405436c6c9cSStella Laurenzo 
1406436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1407436c6c9cSStella Laurenzo   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1408436c6c9cSStella Laurenzo   if (mlirOperationIsNull(firstOp)) {
1409436c6c9cSStella Laurenzo     // Just insert at end.
1410436c6c9cSStella Laurenzo     return PyInsertionPoint(block);
1411436c6c9cSStella Laurenzo   }
1412436c6c9cSStella Laurenzo 
1413436c6c9cSStella Laurenzo   // Insert before first op.
1414436c6c9cSStella Laurenzo   PyOperationRef firstOpRef = PyOperation::forOperation(
1415436c6c9cSStella Laurenzo       block.getParentOperation()->getContext(), firstOp);
1416436c6c9cSStella Laurenzo   return PyInsertionPoint{block, std::move(firstOpRef)};
1417436c6c9cSStella Laurenzo }
1418436c6c9cSStella Laurenzo 
1419436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1420436c6c9cSStella Laurenzo   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1421436c6c9cSStella Laurenzo   if (mlirOperationIsNull(terminator))
1422436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1423436c6c9cSStella Laurenzo   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1424436c6c9cSStella Laurenzo       block.getParentOperation()->getContext(), terminator);
1425436c6c9cSStella Laurenzo   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1426436c6c9cSStella Laurenzo }
1427436c6c9cSStella Laurenzo 
1428436c6c9cSStella Laurenzo py::object PyInsertionPoint::contextEnter() {
1429436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushInsertionPoint(*this);
1430436c6c9cSStella Laurenzo }
1431436c6c9cSStella Laurenzo 
1432436c6c9cSStella Laurenzo void PyInsertionPoint::contextExit(pybind11::object excType,
1433436c6c9cSStella Laurenzo                                    pybind11::object excVal,
1434436c6c9cSStella Laurenzo                                    pybind11::object excTb) {
1435436c6c9cSStella Laurenzo   PyThreadContextEntry::popInsertionPoint(*this);
1436436c6c9cSStella Laurenzo }
1437436c6c9cSStella Laurenzo 
1438436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1439436c6c9cSStella Laurenzo // PyAttribute.
1440436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1441436c6c9cSStella Laurenzo 
1442436c6c9cSStella Laurenzo bool PyAttribute::operator==(const PyAttribute &other) {
1443436c6c9cSStella Laurenzo   return mlirAttributeEqual(attr, other.attr);
1444436c6c9cSStella Laurenzo }
1445436c6c9cSStella Laurenzo 
1446436c6c9cSStella Laurenzo py::object PyAttribute::getCapsule() {
1447436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1448436c6c9cSStella Laurenzo }
1449436c6c9cSStella Laurenzo 
1450436c6c9cSStella Laurenzo PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1451436c6c9cSStella Laurenzo   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1452436c6c9cSStella Laurenzo   if (mlirAttributeIsNull(rawAttr))
1453436c6c9cSStella Laurenzo     throw py::error_already_set();
1454436c6c9cSStella Laurenzo   return PyAttribute(
1455436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1456436c6c9cSStella Laurenzo }
1457436c6c9cSStella Laurenzo 
1458436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1459436c6c9cSStella Laurenzo // PyNamedAttribute.
1460436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1461436c6c9cSStella Laurenzo 
1462436c6c9cSStella Laurenzo PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1463436c6c9cSStella Laurenzo     : ownedName(new std::string(std::move(ownedName))) {
1464436c6c9cSStella Laurenzo   namedAttr = mlirNamedAttributeGet(
1465436c6c9cSStella Laurenzo       mlirIdentifierGet(mlirAttributeGetContext(attr),
1466436c6c9cSStella Laurenzo                         toMlirStringRef(*this->ownedName)),
1467436c6c9cSStella Laurenzo       attr);
1468436c6c9cSStella Laurenzo }
1469436c6c9cSStella Laurenzo 
1470436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1471436c6c9cSStella Laurenzo // PyType.
1472436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1473436c6c9cSStella Laurenzo 
1474436c6c9cSStella Laurenzo bool PyType::operator==(const PyType &other) {
1475436c6c9cSStella Laurenzo   return mlirTypeEqual(type, other.type);
1476436c6c9cSStella Laurenzo }
1477436c6c9cSStella Laurenzo 
1478436c6c9cSStella Laurenzo py::object PyType::getCapsule() {
1479436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1480436c6c9cSStella Laurenzo }
1481436c6c9cSStella Laurenzo 
1482436c6c9cSStella Laurenzo PyType PyType::createFromCapsule(py::object capsule) {
1483436c6c9cSStella Laurenzo   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1484436c6c9cSStella Laurenzo   if (mlirTypeIsNull(rawType))
1485436c6c9cSStella Laurenzo     throw py::error_already_set();
1486436c6c9cSStella Laurenzo   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1487436c6c9cSStella Laurenzo                 rawType);
1488436c6c9cSStella Laurenzo }
1489436c6c9cSStella Laurenzo 
1490436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
// PyValue and subclasses.
1492436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1493436c6c9cSStella Laurenzo 
14943f3d1c90SMike Urbach pybind11::object PyValue::getCapsule() {
14953f3d1c90SMike Urbach   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
14963f3d1c90SMike Urbach }
14973f3d1c90SMike Urbach 
14983f3d1c90SMike Urbach PyValue PyValue::createFromCapsule(pybind11::object capsule) {
14993f3d1c90SMike Urbach   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
15003f3d1c90SMike Urbach   if (mlirValueIsNull(value))
15013f3d1c90SMike Urbach     throw py::error_already_set();
15023f3d1c90SMike Urbach   MlirOperation owner;
15033f3d1c90SMike Urbach   if (mlirValueIsAOpResult(value))
15043f3d1c90SMike Urbach     owner = mlirOpResultGetOwner(value);
15053f3d1c90SMike Urbach   if (mlirValueIsABlockArgument(value))
15063f3d1c90SMike Urbach     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
15073f3d1c90SMike Urbach   if (mlirOperationIsNull(owner))
15083f3d1c90SMike Urbach     throw py::error_already_set();
15093f3d1c90SMike Urbach   MlirContext ctx = mlirOperationGetContext(owner);
15103f3d1c90SMike Urbach   PyOperationRef ownerRef =
15113f3d1c90SMike Urbach       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
15123f3d1c90SMike Urbach   return PyValue(ownerRef, value);
15133f3d1c90SMike Urbach }
15143f3d1c90SMike Urbach 
1515436c6c9cSStella Laurenzo namespace {
1516436c6c9cSStella Laurenzo /// CRTP base class for Python MLIR values that subclass Value and should be
1517436c6c9cSStella Laurenzo /// castable from it. The value hierarchy is one level deep and is not supposed
1518436c6c9cSStella Laurenzo /// to accommodate other levels unless core MLIR changes.
1519436c6c9cSStella Laurenzo template <typename DerivedTy>
1520436c6c9cSStella Laurenzo class PyConcreteValue : public PyValue {
1521436c6c9cSStella Laurenzo public:
1522436c6c9cSStella Laurenzo   // Derived classes must define statics for:
1523436c6c9cSStella Laurenzo   //   IsAFunctionTy isaFunction
1524436c6c9cSStella Laurenzo   //   const char *pyClassName
1525436c6c9cSStella Laurenzo   // and redefine bindDerived.
1526436c6c9cSStella Laurenzo   using ClassTy = py::class_<DerivedTy, PyValue>;
1527436c6c9cSStella Laurenzo   using IsAFunctionTy = bool (*)(MlirValue);
1528436c6c9cSStella Laurenzo 
1529436c6c9cSStella Laurenzo   PyConcreteValue() = default;
1530436c6c9cSStella Laurenzo   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1531436c6c9cSStella Laurenzo       : PyValue(operationRef, value) {}
1532436c6c9cSStella Laurenzo   PyConcreteValue(PyValue &orig)
1533436c6c9cSStella Laurenzo       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1534436c6c9cSStella Laurenzo 
1535436c6c9cSStella Laurenzo   /// Attempts to cast the original value to the derived type and throws on
1536436c6c9cSStella Laurenzo   /// type mismatches.
1537436c6c9cSStella Laurenzo   static MlirValue castFrom(PyValue &orig) {
1538436c6c9cSStella Laurenzo     if (!DerivedTy::isaFunction(orig.get())) {
1539436c6c9cSStella Laurenzo       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1540436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1541436c6c9cSStella Laurenzo                                              DerivedTy::pyClassName +
1542436c6c9cSStella Laurenzo                                              " (from " + origRepr + ")");
1543436c6c9cSStella Laurenzo     }
1544436c6c9cSStella Laurenzo     return orig.get();
1545436c6c9cSStella Laurenzo   }
1546436c6c9cSStella Laurenzo 
1547436c6c9cSStella Laurenzo   /// Binds the Python module objects to functions of this class.
1548436c6c9cSStella Laurenzo   static void bind(py::module &m) {
1549f05ff4f7SStella Laurenzo     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1550436c6c9cSStella Laurenzo     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1551436c6c9cSStella Laurenzo     DerivedTy::bindDerived(cls);
1552436c6c9cSStella Laurenzo   }
1553436c6c9cSStella Laurenzo 
1554436c6c9cSStella Laurenzo   /// Implemented by derived classes to add methods to the Python subclass.
1555436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &m) {}
1556436c6c9cSStella Laurenzo };
1557436c6c9cSStella Laurenzo 
1558436c6c9cSStella Laurenzo /// Python wrapper for MlirBlockArgument.
1559436c6c9cSStella Laurenzo class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1560436c6c9cSStella Laurenzo public:
1561436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1562436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BlockArgument";
1563436c6c9cSStella Laurenzo   using PyConcreteValue::PyConcreteValue;
1564436c6c9cSStella Laurenzo 
1565436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1566436c6c9cSStella Laurenzo     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1567436c6c9cSStella Laurenzo       return PyBlock(self.getParentOperation(),
1568436c6c9cSStella Laurenzo                      mlirBlockArgumentGetOwner(self.get()));
1569436c6c9cSStella Laurenzo     });
1570436c6c9cSStella Laurenzo     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1571436c6c9cSStella Laurenzo       return mlirBlockArgumentGetArgNumber(self.get());
1572436c6c9cSStella Laurenzo     });
1573436c6c9cSStella Laurenzo     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1574436c6c9cSStella Laurenzo       return mlirBlockArgumentSetType(self.get(), type);
1575436c6c9cSStella Laurenzo     });
1576436c6c9cSStella Laurenzo   }
1577436c6c9cSStella Laurenzo };
1578436c6c9cSStella Laurenzo 
1579436c6c9cSStella Laurenzo /// Python wrapper for MlirOpResult.
1580436c6c9cSStella Laurenzo class PyOpResult : public PyConcreteValue<PyOpResult> {
1581436c6c9cSStella Laurenzo public:
1582436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1583436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpResult";
1584436c6c9cSStella Laurenzo   using PyConcreteValue::PyConcreteValue;
1585436c6c9cSStella Laurenzo 
1586436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1587436c6c9cSStella Laurenzo     c.def_property_readonly("owner", [](PyOpResult &self) {
1588436c6c9cSStella Laurenzo       assert(
1589436c6c9cSStella Laurenzo           mlirOperationEqual(self.getParentOperation()->get(),
1590436c6c9cSStella Laurenzo                              mlirOpResultGetOwner(self.get())) &&
1591436c6c9cSStella Laurenzo           "expected the owner of the value in Python to match that in the IR");
15926ff74f96SMike Urbach       return self.getParentOperation().getObject();
1593436c6c9cSStella Laurenzo     });
1594436c6c9cSStella Laurenzo     c.def_property_readonly("result_number", [](PyOpResult &self) {
1595436c6c9cSStella Laurenzo       return mlirOpResultGetResultNumber(self.get());
1596436c6c9cSStella Laurenzo     });
1597436c6c9cSStella Laurenzo   }
1598436c6c9cSStella Laurenzo };
1599436c6c9cSStella Laurenzo 
1600ed9e52f3SAlex Zinenko /// Returns the list of types of the values held by container.
1601ed9e52f3SAlex Zinenko template <typename Container>
1602ed9e52f3SAlex Zinenko static std::vector<PyType> getValueTypes(Container &container,
1603ed9e52f3SAlex Zinenko                                          PyMlirContextRef &context) {
1604ed9e52f3SAlex Zinenko   std::vector<PyType> result;
1605ed9e52f3SAlex Zinenko   result.reserve(container.getNumElements());
1606ed9e52f3SAlex Zinenko   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1607ed9e52f3SAlex Zinenko     result.push_back(
1608ed9e52f3SAlex Zinenko         PyType(context, mlirValueGetType(container.getElement(i).get())));
1609ed9e52f3SAlex Zinenko   }
1610ed9e52f3SAlex Zinenko   return result;
1611ed9e52f3SAlex Zinenko }
1612ed9e52f3SAlex Zinenko 
1613436c6c9cSStella Laurenzo /// A list of block arguments. Internally, these are stored as consecutive
1614436c6c9cSStella Laurenzo /// elements, random access is cheap. The argument list is associated with the
1615436c6c9cSStella Laurenzo /// operation that contains the block (detached blocks are not allowed in
1616436c6c9cSStella Laurenzo /// Python bindings) and extends its lifetime.
1617afeda4b9SAlex Zinenko class PyBlockArgumentList
1618afeda4b9SAlex Zinenko     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1619436c6c9cSStella Laurenzo public:
1620afeda4b9SAlex Zinenko   static constexpr const char *pyClassName = "BlockArgumentList";
1621436c6c9cSStella Laurenzo 
1622afeda4b9SAlex Zinenko   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1623afeda4b9SAlex Zinenko                       intptr_t startIndex = 0, intptr_t length = -1,
1624afeda4b9SAlex Zinenko                       intptr_t step = 1)
1625afeda4b9SAlex Zinenko       : Sliceable(startIndex,
1626afeda4b9SAlex Zinenko                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1627afeda4b9SAlex Zinenko                   step),
1628afeda4b9SAlex Zinenko         operation(std::move(operation)), block(block) {}
1629afeda4b9SAlex Zinenko 
1630afeda4b9SAlex Zinenko   /// Returns the number of arguments in the list.
1631afeda4b9SAlex Zinenko   intptr_t getNumElements() {
1632436c6c9cSStella Laurenzo     operation->checkValid();
1633436c6c9cSStella Laurenzo     return mlirBlockGetNumArguments(block);
1634436c6c9cSStella Laurenzo   }
1635436c6c9cSStella Laurenzo 
1636afeda4b9SAlex Zinenko   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1637afeda4b9SAlex Zinenko   PyBlockArgument getElement(intptr_t pos) {
1638afeda4b9SAlex Zinenko     MlirValue argument = mlirBlockGetArgument(block, pos);
1639afeda4b9SAlex Zinenko     return PyBlockArgument(operation, argument);
1640436c6c9cSStella Laurenzo   }
1641436c6c9cSStella Laurenzo 
1642afeda4b9SAlex Zinenko   /// Returns a sublist of this list.
1643afeda4b9SAlex Zinenko   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1644afeda4b9SAlex Zinenko                             intptr_t step) {
1645afeda4b9SAlex Zinenko     return PyBlockArgumentList(operation, block, startIndex, length, step);
1646436c6c9cSStella Laurenzo   }
1647436c6c9cSStella Laurenzo 
1648ed9e52f3SAlex Zinenko   static void bindDerived(ClassTy &c) {
1649ed9e52f3SAlex Zinenko     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1650ed9e52f3SAlex Zinenko       return getValueTypes(self, self.operation->getContext());
1651ed9e52f3SAlex Zinenko     });
1652ed9e52f3SAlex Zinenko   }
1653ed9e52f3SAlex Zinenko 
1654436c6c9cSStella Laurenzo private:
1655436c6c9cSStella Laurenzo   PyOperationRef operation;
1656436c6c9cSStella Laurenzo   MlirBlock block;
1657436c6c9cSStella Laurenzo };
1658436c6c9cSStella Laurenzo 
1659436c6c9cSStella Laurenzo /// A list of operation operands. Internally, these are stored as consecutive
1660436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the
1661436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this
1662436c6c9cSStella Laurenzo /// operation.
1663436c6c9cSStella Laurenzo class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1664436c6c9cSStella Laurenzo public:
1665436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpOperandList";
1666436c6c9cSStella Laurenzo 
1667436c6c9cSStella Laurenzo   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1668436c6c9cSStella Laurenzo                   intptr_t length = -1, intptr_t step = 1)
1669436c6c9cSStella Laurenzo       : Sliceable(startIndex,
1670436c6c9cSStella Laurenzo                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1671436c6c9cSStella Laurenzo                                : length,
1672436c6c9cSStella Laurenzo                   step),
1673436c6c9cSStella Laurenzo         operation(operation) {}
1674436c6c9cSStella Laurenzo 
1675436c6c9cSStella Laurenzo   intptr_t getNumElements() {
1676436c6c9cSStella Laurenzo     operation->checkValid();
1677436c6c9cSStella Laurenzo     return mlirOperationGetNumOperands(operation->get());
1678436c6c9cSStella Laurenzo   }
1679436c6c9cSStella Laurenzo 
1680436c6c9cSStella Laurenzo   PyValue getElement(intptr_t pos) {
16815664c5e2SJohn Demme     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
16825664c5e2SJohn Demme     MlirOperation owner;
16835664c5e2SJohn Demme     if (mlirValueIsAOpResult(operand))
16845664c5e2SJohn Demme       owner = mlirOpResultGetOwner(operand);
16855664c5e2SJohn Demme     else if (mlirValueIsABlockArgument(operand))
16865664c5e2SJohn Demme       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
16875664c5e2SJohn Demme     else
16885664c5e2SJohn Demme       assert(false && "Value must be an block arg or op result.");
16895664c5e2SJohn Demme     PyOperationRef pyOwner =
16905664c5e2SJohn Demme         PyOperation::forOperation(operation->getContext(), owner);
16915664c5e2SJohn Demme     return PyValue(pyOwner, operand);
1692436c6c9cSStella Laurenzo   }
1693436c6c9cSStella Laurenzo 
1694436c6c9cSStella Laurenzo   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1695436c6c9cSStella Laurenzo     return PyOpOperandList(operation, startIndex, length, step);
1696436c6c9cSStella Laurenzo   }
1697436c6c9cSStella Laurenzo 
169863d16d06SMike Urbach   void dunderSetItem(intptr_t index, PyValue value) {
169963d16d06SMike Urbach     index = wrapIndex(index);
170063d16d06SMike Urbach     mlirOperationSetOperand(operation->get(), index, value.get());
170163d16d06SMike Urbach   }
170263d16d06SMike Urbach 
170363d16d06SMike Urbach   static void bindDerived(ClassTy &c) {
170463d16d06SMike Urbach     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
170563d16d06SMike Urbach   }
170663d16d06SMike Urbach 
1707436c6c9cSStella Laurenzo private:
1708436c6c9cSStella Laurenzo   PyOperationRef operation;
1709436c6c9cSStella Laurenzo };
1710436c6c9cSStella Laurenzo 
1711436c6c9cSStella Laurenzo /// A list of operation results. Internally, these are stored as consecutive
1712436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the
1713436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this
1714436c6c9cSStella Laurenzo /// operation.
1715436c6c9cSStella Laurenzo class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1716436c6c9cSStella Laurenzo public:
1717436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpResultList";
1718436c6c9cSStella Laurenzo 
1719436c6c9cSStella Laurenzo   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1720436c6c9cSStella Laurenzo                  intptr_t length = -1, intptr_t step = 1)
1721436c6c9cSStella Laurenzo       : Sliceable(startIndex,
1722436c6c9cSStella Laurenzo                   length == -1 ? mlirOperationGetNumResults(operation->get())
1723436c6c9cSStella Laurenzo                                : length,
1724436c6c9cSStella Laurenzo                   step),
1725436c6c9cSStella Laurenzo         operation(operation) {}
1726436c6c9cSStella Laurenzo 
1727436c6c9cSStella Laurenzo   intptr_t getNumElements() {
1728436c6c9cSStella Laurenzo     operation->checkValid();
1729436c6c9cSStella Laurenzo     return mlirOperationGetNumResults(operation->get());
1730436c6c9cSStella Laurenzo   }
1731436c6c9cSStella Laurenzo 
1732436c6c9cSStella Laurenzo   PyOpResult getElement(intptr_t index) {
1733436c6c9cSStella Laurenzo     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1734436c6c9cSStella Laurenzo     return PyOpResult(value);
1735436c6c9cSStella Laurenzo   }
1736436c6c9cSStella Laurenzo 
1737436c6c9cSStella Laurenzo   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1738436c6c9cSStella Laurenzo     return PyOpResultList(operation, startIndex, length, step);
1739436c6c9cSStella Laurenzo   }
1740436c6c9cSStella Laurenzo 
1741ed9e52f3SAlex Zinenko   static void bindDerived(ClassTy &c) {
1742ed9e52f3SAlex Zinenko     c.def_property_readonly("types", [](PyOpResultList &self) {
1743ed9e52f3SAlex Zinenko       return getValueTypes(self, self.operation->getContext());
1744ed9e52f3SAlex Zinenko     });
1745ed9e52f3SAlex Zinenko   }
1746ed9e52f3SAlex Zinenko 
1747436c6c9cSStella Laurenzo private:
1748436c6c9cSStella Laurenzo   PyOperationRef operation;
1749436c6c9cSStella Laurenzo };
1750436c6c9cSStella Laurenzo 
1751436c6c9cSStella Laurenzo /// A list of operation attributes. Can be indexed by name, producing
1752436c6c9cSStella Laurenzo /// attributes, or by index, producing named attributes.
1753436c6c9cSStella Laurenzo class PyOpAttributeMap {
1754436c6c9cSStella Laurenzo public:
1755436c6c9cSStella Laurenzo   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1756436c6c9cSStella Laurenzo 
1757436c6c9cSStella Laurenzo   PyAttribute dunderGetItemNamed(const std::string &name) {
1758436c6c9cSStella Laurenzo     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1759436c6c9cSStella Laurenzo                                                          toMlirStringRef(name));
1760436c6c9cSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
1761436c6c9cSStella Laurenzo       throw SetPyError(PyExc_KeyError,
1762436c6c9cSStella Laurenzo                        "attempt to access a non-existent attribute");
1763436c6c9cSStella Laurenzo     }
1764436c6c9cSStella Laurenzo     return PyAttribute(operation->getContext(), attr);
1765436c6c9cSStella Laurenzo   }
1766436c6c9cSStella Laurenzo 
1767436c6c9cSStella Laurenzo   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1768436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
1769436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
1770436c6c9cSStella Laurenzo                        "attempt to access out of bounds attribute");
1771436c6c9cSStella Laurenzo     }
1772436c6c9cSStella Laurenzo     MlirNamedAttribute namedAttr =
1773436c6c9cSStella Laurenzo         mlirOperationGetAttribute(operation->get(), index);
1774436c6c9cSStella Laurenzo     return PyNamedAttribute(
1775436c6c9cSStella Laurenzo         namedAttr.attribute,
1776436c6c9cSStella Laurenzo         std::string(mlirIdentifierStr(namedAttr.name).data));
1777436c6c9cSStella Laurenzo   }
1778436c6c9cSStella Laurenzo 
1779436c6c9cSStella Laurenzo   void dunderSetItem(const std::string &name, PyAttribute attr) {
1780436c6c9cSStella Laurenzo     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1781436c6c9cSStella Laurenzo                                     attr);
1782436c6c9cSStella Laurenzo   }
1783436c6c9cSStella Laurenzo 
1784436c6c9cSStella Laurenzo   void dunderDelItem(const std::string &name) {
1785436c6c9cSStella Laurenzo     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1786436c6c9cSStella Laurenzo                                                      toMlirStringRef(name));
1787436c6c9cSStella Laurenzo     if (!removed)
1788436c6c9cSStella Laurenzo       throw SetPyError(PyExc_KeyError,
1789436c6c9cSStella Laurenzo                        "attempt to delete a non-existent attribute");
1790436c6c9cSStella Laurenzo   }
1791436c6c9cSStella Laurenzo 
1792436c6c9cSStella Laurenzo   intptr_t dunderLen() {
1793436c6c9cSStella Laurenzo     return mlirOperationGetNumAttributes(operation->get());
1794436c6c9cSStella Laurenzo   }
1795436c6c9cSStella Laurenzo 
1796436c6c9cSStella Laurenzo   bool dunderContains(const std::string &name) {
1797436c6c9cSStella Laurenzo     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1798436c6c9cSStella Laurenzo         operation->get(), toMlirStringRef(name)));
1799436c6c9cSStella Laurenzo   }
1800436c6c9cSStella Laurenzo 
1801436c6c9cSStella Laurenzo   static void bind(py::module &m) {
1802f05ff4f7SStella Laurenzo     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
1803436c6c9cSStella Laurenzo         .def("__contains__", &PyOpAttributeMap::dunderContains)
1804436c6c9cSStella Laurenzo         .def("__len__", &PyOpAttributeMap::dunderLen)
1805436c6c9cSStella Laurenzo         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1806436c6c9cSStella Laurenzo         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1807436c6c9cSStella Laurenzo         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1808436c6c9cSStella Laurenzo         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1809436c6c9cSStella Laurenzo   }
1810436c6c9cSStella Laurenzo 
1811436c6c9cSStella Laurenzo private:
1812436c6c9cSStella Laurenzo   PyOperationRef operation;
1813436c6c9cSStella Laurenzo };
1814436c6c9cSStella Laurenzo 
1815436c6c9cSStella Laurenzo } // end namespace
1816436c6c9cSStella Laurenzo 
1817436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1818436c6c9cSStella Laurenzo // Populates the core exports of the 'ir' submodule.
1819436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1820436c6c9cSStella Laurenzo 
1821436c6c9cSStella Laurenzo void mlir::python::populateIRCore(py::module &m) {
1822436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
18234acd8457SAlex Zinenko   // Mapping of MlirContext.
1824436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1825f05ff4f7SStella Laurenzo   py::class_<PyMlirContext>(m, "Context", py::module_local())
1826436c6c9cSStella Laurenzo       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1827436c6c9cSStella Laurenzo       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1828436c6c9cSStella Laurenzo       .def("_get_context_again",
1829436c6c9cSStella Laurenzo            [](PyMlirContext &self) {
1830436c6c9cSStella Laurenzo              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1831436c6c9cSStella Laurenzo              return ref.releaseObject();
1832436c6c9cSStella Laurenzo            })
1833436c6c9cSStella Laurenzo       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1834436c6c9cSStella Laurenzo       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1835436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1836436c6c9cSStella Laurenzo                              &PyMlirContext::getCapsule)
1837436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1838436c6c9cSStella Laurenzo       .def("__enter__", &PyMlirContext::contextEnter)
1839436c6c9cSStella Laurenzo       .def("__exit__", &PyMlirContext::contextExit)
1840436c6c9cSStella Laurenzo       .def_property_readonly_static(
1841436c6c9cSStella Laurenzo           "current",
1842436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
1843436c6c9cSStella Laurenzo             auto *context = PyThreadContextEntry::getDefaultContext();
1844436c6c9cSStella Laurenzo             if (!context)
1845436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current Context");
1846436c6c9cSStella Laurenzo             return context;
1847436c6c9cSStella Laurenzo           },
1848436c6c9cSStella Laurenzo           "Gets the Context bound to the current thread or raises ValueError")
1849436c6c9cSStella Laurenzo       .def_property_readonly(
1850436c6c9cSStella Laurenzo           "dialects",
1851436c6c9cSStella Laurenzo           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1852436c6c9cSStella Laurenzo           "Gets a container for accessing dialects by name")
1853436c6c9cSStella Laurenzo       .def_property_readonly(
1854436c6c9cSStella Laurenzo           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1855436c6c9cSStella Laurenzo           "Alias for 'dialect'")
1856436c6c9cSStella Laurenzo       .def(
1857436c6c9cSStella Laurenzo           "get_dialect_descriptor",
1858436c6c9cSStella Laurenzo           [=](PyMlirContext &self, std::string &name) {
1859436c6c9cSStella Laurenzo             MlirDialect dialect = mlirContextGetOrLoadDialect(
1860436c6c9cSStella Laurenzo                 self.get(), {name.data(), name.size()});
1861436c6c9cSStella Laurenzo             if (mlirDialectIsNull(dialect)) {
1862436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
1863436c6c9cSStella Laurenzo                                Twine("Dialect '") + name + "' not found");
1864436c6c9cSStella Laurenzo             }
1865436c6c9cSStella Laurenzo             return PyDialectDescriptor(self.getRef(), dialect);
1866436c6c9cSStella Laurenzo           },
1867436c6c9cSStella Laurenzo           "Gets or loads a dialect by name, returning its descriptor object")
1868436c6c9cSStella Laurenzo       .def_property(
1869436c6c9cSStella Laurenzo           "allow_unregistered_dialects",
1870436c6c9cSStella Laurenzo           [](PyMlirContext &self) -> bool {
1871436c6c9cSStella Laurenzo             return mlirContextGetAllowUnregisteredDialects(self.get());
1872436c6c9cSStella Laurenzo           },
1873436c6c9cSStella Laurenzo           [](PyMlirContext &self, bool value) {
1874436c6c9cSStella Laurenzo             mlirContextSetAllowUnregisteredDialects(self.get(), value);
18759a9214faSStella Laurenzo           })
1876caa159f0SNicolas Vasilache       .def("enable_multithreading",
1877caa159f0SNicolas Vasilache            [](PyMlirContext &self, bool enable) {
1878caa159f0SNicolas Vasilache              mlirContextEnableMultithreading(self.get(), enable);
1879caa159f0SNicolas Vasilache            })
18809a9214faSStella Laurenzo       .def("is_registered_operation",
18819a9214faSStella Laurenzo            [](PyMlirContext &self, std::string &name) {
18829a9214faSStella Laurenzo              return mlirContextIsRegisteredOperation(
18839a9214faSStella Laurenzo                  self.get(), MlirStringRef{name.data(), name.size()});
1884436c6c9cSStella Laurenzo            });
1885436c6c9cSStella Laurenzo 
1886436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1887436c6c9cSStella Laurenzo   // Mapping of PyDialectDescriptor
1888436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1889f05ff4f7SStella Laurenzo   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
1890436c6c9cSStella Laurenzo       .def_property_readonly("namespace",
1891436c6c9cSStella Laurenzo                              [](PyDialectDescriptor &self) {
1892436c6c9cSStella Laurenzo                                MlirStringRef ns =
1893436c6c9cSStella Laurenzo                                    mlirDialectGetNamespace(self.get());
1894436c6c9cSStella Laurenzo                                return py::str(ns.data, ns.length);
1895436c6c9cSStella Laurenzo                              })
1896436c6c9cSStella Laurenzo       .def("__repr__", [](PyDialectDescriptor &self) {
1897436c6c9cSStella Laurenzo         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1898436c6c9cSStella Laurenzo         std::string repr("<DialectDescriptor ");
1899436c6c9cSStella Laurenzo         repr.append(ns.data, ns.length);
1900436c6c9cSStella Laurenzo         repr.append(">");
1901436c6c9cSStella Laurenzo         return repr;
1902436c6c9cSStella Laurenzo       });
1903436c6c9cSStella Laurenzo 
1904436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1905436c6c9cSStella Laurenzo   // Mapping of PyDialects
1906436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1907f05ff4f7SStella Laurenzo   py::class_<PyDialects>(m, "Dialects", py::module_local())
1908436c6c9cSStella Laurenzo       .def("__getitem__",
1909436c6c9cSStella Laurenzo            [=](PyDialects &self, std::string keyName) {
1910436c6c9cSStella Laurenzo              MlirDialect dialect =
1911436c6c9cSStella Laurenzo                  self.getDialectForKey(keyName, /*attrError=*/false);
1912436c6c9cSStella Laurenzo              py::object descriptor =
1913436c6c9cSStella Laurenzo                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1914436c6c9cSStella Laurenzo              return createCustomDialectWrapper(keyName, std::move(descriptor));
1915436c6c9cSStella Laurenzo            })
1916436c6c9cSStella Laurenzo       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1917436c6c9cSStella Laurenzo         MlirDialect dialect =
1918436c6c9cSStella Laurenzo             self.getDialectForKey(attrName, /*attrError=*/true);
1919436c6c9cSStella Laurenzo         py::object descriptor =
1920436c6c9cSStella Laurenzo             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1921436c6c9cSStella Laurenzo         return createCustomDialectWrapper(attrName, std::move(descriptor));
1922436c6c9cSStella Laurenzo       });
1923436c6c9cSStella Laurenzo 
1924436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1925436c6c9cSStella Laurenzo   // Mapping of PyDialect
1926436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1927f05ff4f7SStella Laurenzo   py::class_<PyDialect>(m, "Dialect", py::module_local())
1928436c6c9cSStella Laurenzo       .def(py::init<py::object>(), "descriptor")
1929436c6c9cSStella Laurenzo       .def_property_readonly(
1930436c6c9cSStella Laurenzo           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1931436c6c9cSStella Laurenzo       .def("__repr__", [](py::object self) {
1932436c6c9cSStella Laurenzo         auto clazz = self.attr("__class__");
1933436c6c9cSStella Laurenzo         return py::str("<Dialect ") +
1934436c6c9cSStella Laurenzo                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1935436c6c9cSStella Laurenzo                clazz.attr("__module__") + py::str(".") +
1936436c6c9cSStella Laurenzo                clazz.attr("__name__") + py::str(")>");
1937436c6c9cSStella Laurenzo       });
1938436c6c9cSStella Laurenzo 
1939436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1940436c6c9cSStella Laurenzo   // Mapping of Location
1941436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1942f05ff4f7SStella Laurenzo   py::class_<PyLocation>(m, "Location", py::module_local())
1943436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
1944436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
1945436c6c9cSStella Laurenzo       .def("__enter__", &PyLocation::contextEnter)
1946436c6c9cSStella Laurenzo       .def("__exit__", &PyLocation::contextExit)
1947436c6c9cSStella Laurenzo       .def("__eq__",
1948436c6c9cSStella Laurenzo            [](PyLocation &self, PyLocation &other) -> bool {
1949436c6c9cSStella Laurenzo              return mlirLocationEqual(self, other);
1950436c6c9cSStella Laurenzo            })
1951436c6c9cSStella Laurenzo       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
1952436c6c9cSStella Laurenzo       .def_property_readonly_static(
1953436c6c9cSStella Laurenzo           "current",
1954436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
1955436c6c9cSStella Laurenzo             auto *loc = PyThreadContextEntry::getDefaultLocation();
1956436c6c9cSStella Laurenzo             if (!loc)
1957436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current Location");
1958436c6c9cSStella Laurenzo             return loc;
1959436c6c9cSStella Laurenzo           },
1960436c6c9cSStella Laurenzo           "Gets the Location bound to the current thread or raises ValueError")
1961436c6c9cSStella Laurenzo       .def_static(
1962436c6c9cSStella Laurenzo           "unknown",
1963436c6c9cSStella Laurenzo           [](DefaultingPyMlirContext context) {
1964436c6c9cSStella Laurenzo             return PyLocation(context->getRef(),
1965436c6c9cSStella Laurenzo                               mlirLocationUnknownGet(context->get()));
1966436c6c9cSStella Laurenzo           },
1967436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
1968436c6c9cSStella Laurenzo           "Gets a Location representing an unknown location")
1969436c6c9cSStella Laurenzo       .def_static(
1970*e67cbbefSJacques Pienaar           "callsite",
1971*e67cbbefSJacques Pienaar           [](PyLocation callee, const std::vector<PyLocation> &frames,
1972*e67cbbefSJacques Pienaar              DefaultingPyMlirContext context) {
1973*e67cbbefSJacques Pienaar             if (frames.empty())
1974*e67cbbefSJacques Pienaar               throw py::value_error("No caller frames provided");
1975*e67cbbefSJacques Pienaar             MlirLocation caller = frames.back().get();
1976*e67cbbefSJacques Pienaar             for (PyLocation frame :
1977*e67cbbefSJacques Pienaar                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
1978*e67cbbefSJacques Pienaar               caller = mlirLocationCallSiteGet(frame.get(), caller);
1979*e67cbbefSJacques Pienaar             return PyLocation(context->getRef(),
1980*e67cbbefSJacques Pienaar                               mlirLocationCallSiteGet(callee.get(), caller));
1981*e67cbbefSJacques Pienaar           },
1982*e67cbbefSJacques Pienaar           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
1983*e67cbbefSJacques Pienaar           kContextGetCallSiteLocationDocstring)
1984*e67cbbefSJacques Pienaar       .def_static(
1985436c6c9cSStella Laurenzo           "file",
1986436c6c9cSStella Laurenzo           [](std::string filename, int line, int col,
1987436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
1988436c6c9cSStella Laurenzo             return PyLocation(
1989436c6c9cSStella Laurenzo                 context->getRef(),
1990436c6c9cSStella Laurenzo                 mlirLocationFileLineColGet(
1991436c6c9cSStella Laurenzo                     context->get(), toMlirStringRef(filename), line, col));
1992436c6c9cSStella Laurenzo           },
1993436c6c9cSStella Laurenzo           py::arg("filename"), py::arg("line"), py::arg("col"),
1994436c6c9cSStella Laurenzo           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
199504d76d36SJacques Pienaar       .def_static(
199604d76d36SJacques Pienaar           "name",
199704d76d36SJacques Pienaar           [](std::string name, llvm::Optional<PyLocation> childLoc,
199804d76d36SJacques Pienaar              DefaultingPyMlirContext context) {
199904d76d36SJacques Pienaar             return PyLocation(
200004d76d36SJacques Pienaar                 context->getRef(),
200104d76d36SJacques Pienaar                 mlirLocationNameGet(
200204d76d36SJacques Pienaar                     context->get(), toMlirStringRef(name),
200304d76d36SJacques Pienaar                     childLoc ? childLoc->get()
200404d76d36SJacques Pienaar                              : mlirLocationUnknownGet(context->get())));
200504d76d36SJacques Pienaar           },
200604d76d36SJacques Pienaar           py::arg("name"), py::arg("childLoc") = py::none(),
200704d76d36SJacques Pienaar           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2008436c6c9cSStella Laurenzo       .def_property_readonly(
2009436c6c9cSStella Laurenzo           "context",
2010436c6c9cSStella Laurenzo           [](PyLocation &self) { return self.getContext().getObject(); },
2011436c6c9cSStella Laurenzo           "Context that owns the Location")
2012436c6c9cSStella Laurenzo       .def("__repr__", [](PyLocation &self) {
2013436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
2014436c6c9cSStella Laurenzo         mlirLocationPrint(self, printAccum.getCallback(),
2015436c6c9cSStella Laurenzo                           printAccum.getUserData());
2016436c6c9cSStella Laurenzo         return printAccum.join();
2017436c6c9cSStella Laurenzo       });
2018436c6c9cSStella Laurenzo 
2019436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2020436c6c9cSStella Laurenzo   // Mapping of Module
2021436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2022f05ff4f7SStella Laurenzo   py::class_<PyModule>(m, "Module", py::module_local())
2023436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2024436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2025436c6c9cSStella Laurenzo       .def_static(
2026436c6c9cSStella Laurenzo           "parse",
2027436c6c9cSStella Laurenzo           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2028436c6c9cSStella Laurenzo             MlirModule module = mlirModuleCreateParse(
2029436c6c9cSStella Laurenzo                 context->get(), toMlirStringRef(moduleAsm));
2030436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
2031436c6c9cSStella Laurenzo             // in C API.
2032436c6c9cSStella Laurenzo             if (mlirModuleIsNull(module)) {
2033436c6c9cSStella Laurenzo               throw SetPyError(
2034436c6c9cSStella Laurenzo                   PyExc_ValueError,
2035436c6c9cSStella Laurenzo                   "Unable to parse module assembly (see diagnostics)");
2036436c6c9cSStella Laurenzo             }
2037436c6c9cSStella Laurenzo             return PyModule::forModule(module).releaseObject();
2038436c6c9cSStella Laurenzo           },
2039436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
2040436c6c9cSStella Laurenzo           kModuleParseDocstring)
2041436c6c9cSStella Laurenzo       .def_static(
2042436c6c9cSStella Laurenzo           "create",
2043436c6c9cSStella Laurenzo           [](DefaultingPyLocation loc) {
2044436c6c9cSStella Laurenzo             MlirModule module = mlirModuleCreateEmpty(loc);
2045436c6c9cSStella Laurenzo             return PyModule::forModule(module).releaseObject();
2046436c6c9cSStella Laurenzo           },
2047436c6c9cSStella Laurenzo           py::arg("loc") = py::none(), "Creates an empty module")
2048436c6c9cSStella Laurenzo       .def_property_readonly(
2049436c6c9cSStella Laurenzo           "context",
2050436c6c9cSStella Laurenzo           [](PyModule &self) { return self.getContext().getObject(); },
2051436c6c9cSStella Laurenzo           "Context that created the Module")
2052436c6c9cSStella Laurenzo       .def_property_readonly(
2053436c6c9cSStella Laurenzo           "operation",
2054436c6c9cSStella Laurenzo           [](PyModule &self) {
2055436c6c9cSStella Laurenzo             return PyOperation::forOperation(self.getContext(),
2056436c6c9cSStella Laurenzo                                              mlirModuleGetOperation(self.get()),
2057436c6c9cSStella Laurenzo                                              self.getRef().releaseObject())
2058436c6c9cSStella Laurenzo                 .releaseObject();
2059436c6c9cSStella Laurenzo           },
2060436c6c9cSStella Laurenzo           "Accesses the module as an operation")
2061436c6c9cSStella Laurenzo       .def_property_readonly(
2062436c6c9cSStella Laurenzo           "body",
2063436c6c9cSStella Laurenzo           [](PyModule &self) {
2064436c6c9cSStella Laurenzo             PyOperationRef module_op = PyOperation::forOperation(
2065436c6c9cSStella Laurenzo                 self.getContext(), mlirModuleGetOperation(self.get()),
2066436c6c9cSStella Laurenzo                 self.getRef().releaseObject());
2067436c6c9cSStella Laurenzo             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2068436c6c9cSStella Laurenzo             return returnBlock;
2069436c6c9cSStella Laurenzo           },
2070436c6c9cSStella Laurenzo           "Return the block for this module")
2071436c6c9cSStella Laurenzo       .def(
2072436c6c9cSStella Laurenzo           "dump",
2073436c6c9cSStella Laurenzo           [](PyModule &self) {
2074436c6c9cSStella Laurenzo             mlirOperationDump(mlirModuleGetOperation(self.get()));
2075436c6c9cSStella Laurenzo           },
2076436c6c9cSStella Laurenzo           kDumpDocstring)
2077436c6c9cSStella Laurenzo       .def(
2078436c6c9cSStella Laurenzo           "__str__",
2079436c6c9cSStella Laurenzo           [](PyModule &self) {
2080436c6c9cSStella Laurenzo             MlirOperation operation = mlirModuleGetOperation(self.get());
2081436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2082436c6c9cSStella Laurenzo             mlirOperationPrint(operation, printAccum.getCallback(),
2083436c6c9cSStella Laurenzo                                printAccum.getUserData());
2084436c6c9cSStella Laurenzo             return printAccum.join();
2085436c6c9cSStella Laurenzo           },
2086436c6c9cSStella Laurenzo           kOperationStrDunderDocstring);
2087436c6c9cSStella Laurenzo 
2088436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2089436c6c9cSStella Laurenzo   // Mapping of Operation.
2090436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2091f05ff4f7SStella Laurenzo   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
20921fb2e842SStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
20931fb2e842SStella Laurenzo                              [](PyOperationBase &self) {
20941fb2e842SStella Laurenzo                                return self.getOperation().getCapsule();
20951fb2e842SStella Laurenzo                              })
2096436c6c9cSStella Laurenzo       .def("__eq__",
2097436c6c9cSStella Laurenzo            [](PyOperationBase &self, PyOperationBase &other) {
2098436c6c9cSStella Laurenzo              return &self.getOperation() == &other.getOperation();
2099436c6c9cSStella Laurenzo            })
2100436c6c9cSStella Laurenzo       .def("__eq__",
2101436c6c9cSStella Laurenzo            [](PyOperationBase &self, py::object other) { return false; })
2102436c6c9cSStella Laurenzo       .def_property_readonly("attributes",
2103436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
2104436c6c9cSStella Laurenzo                                return PyOpAttributeMap(
2105436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
2106436c6c9cSStella Laurenzo                              })
2107436c6c9cSStella Laurenzo       .def_property_readonly("operands",
2108436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
2109436c6c9cSStella Laurenzo                                return PyOpOperandList(
2110436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
2111436c6c9cSStella Laurenzo                              })
2112436c6c9cSStella Laurenzo       .def_property_readonly("regions",
2113436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
2114436c6c9cSStella Laurenzo                                return PyRegionList(
2115436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
2116436c6c9cSStella Laurenzo                              })
2117436c6c9cSStella Laurenzo       .def_property_readonly(
2118436c6c9cSStella Laurenzo           "results",
2119436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2120436c6c9cSStella Laurenzo             return PyOpResultList(self.getOperation().getRef());
2121436c6c9cSStella Laurenzo           },
2122436c6c9cSStella Laurenzo           "Returns the list of Operation results.")
2123436c6c9cSStella Laurenzo       .def_property_readonly(
2124436c6c9cSStella Laurenzo           "result",
2125436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2126436c6c9cSStella Laurenzo             auto &operation = self.getOperation();
2127436c6c9cSStella Laurenzo             auto numResults = mlirOperationGetNumResults(operation);
2128436c6c9cSStella Laurenzo             if (numResults != 1) {
2129436c6c9cSStella Laurenzo               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2130436c6c9cSStella Laurenzo               throw SetPyError(
2131436c6c9cSStella Laurenzo                   PyExc_ValueError,
2132436c6c9cSStella Laurenzo                   Twine("Cannot call .result on operation ") +
2133436c6c9cSStella Laurenzo                       StringRef(name.data, name.length) + " which has " +
2134436c6c9cSStella Laurenzo                       Twine(numResults) +
2135436c6c9cSStella Laurenzo                       " results (it is only valid for operations with a "
2136436c6c9cSStella Laurenzo                       "single result)");
2137436c6c9cSStella Laurenzo             }
2138436c6c9cSStella Laurenzo             return PyOpResult(operation.getRef(),
2139436c6c9cSStella Laurenzo                               mlirOperationGetResult(operation, 0));
2140436c6c9cSStella Laurenzo           },
2141436c6c9cSStella Laurenzo           "Shortcut to get an op result if it has only one (throws an error "
2142436c6c9cSStella Laurenzo           "otherwise).")
2143436c6c9cSStella Laurenzo       .def("__iter__",
2144436c6c9cSStella Laurenzo            [](PyOperationBase &self) {
2145436c6c9cSStella Laurenzo              return PyRegionIterator(self.getOperation().getRef());
2146436c6c9cSStella Laurenzo            })
2147436c6c9cSStella Laurenzo       .def(
2148436c6c9cSStella Laurenzo           "__str__",
2149436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2150436c6c9cSStella Laurenzo             return self.getAsm(/*binary=*/false,
2151436c6c9cSStella Laurenzo                                /*largeElementsLimit=*/llvm::None,
2152436c6c9cSStella Laurenzo                                /*enableDebugInfo=*/false,
2153436c6c9cSStella Laurenzo                                /*prettyDebugInfo=*/false,
2154436c6c9cSStella Laurenzo                                /*printGenericOpForm=*/false,
2155436c6c9cSStella Laurenzo                                /*useLocalScope=*/false);
2156436c6c9cSStella Laurenzo           },
2157436c6c9cSStella Laurenzo           "Returns the assembly form of the operation.")
2158436c6c9cSStella Laurenzo       .def("print", &PyOperationBase::print,
2159436c6c9cSStella Laurenzo            // Careful: Lots of arguments must match up with print method.
2160436c6c9cSStella Laurenzo            py::arg("file") = py::none(), py::arg("binary") = false,
2161436c6c9cSStella Laurenzo            py::arg("large_elements_limit") = py::none(),
2162436c6c9cSStella Laurenzo            py::arg("enable_debug_info") = false,
2163436c6c9cSStella Laurenzo            py::arg("pretty_debug_info") = false,
2164436c6c9cSStella Laurenzo            py::arg("print_generic_op_form") = false,
2165436c6c9cSStella Laurenzo            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2166436c6c9cSStella Laurenzo       .def("get_asm", &PyOperationBase::getAsm,
2167436c6c9cSStella Laurenzo            // Careful: Lots of arguments must match up with get_asm method.
2168436c6c9cSStella Laurenzo            py::arg("binary") = false,
2169436c6c9cSStella Laurenzo            py::arg("large_elements_limit") = py::none(),
2170436c6c9cSStella Laurenzo            py::arg("enable_debug_info") = false,
2171436c6c9cSStella Laurenzo            py::arg("pretty_debug_info") = false,
2172436c6c9cSStella Laurenzo            py::arg("print_generic_op_form") = false,
2173436c6c9cSStella Laurenzo            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2174436c6c9cSStella Laurenzo       .def(
2175436c6c9cSStella Laurenzo           "verify",
2176436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2177436c6c9cSStella Laurenzo             return mlirOperationVerify(self.getOperation());
2178436c6c9cSStella Laurenzo           },
2179436c6c9cSStella Laurenzo           "Verify the operation and return true if it passes, false if it "
2180436c6c9cSStella Laurenzo           "fails.");
2181436c6c9cSStella Laurenzo 
2182f05ff4f7SStella Laurenzo   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2183436c6c9cSStella Laurenzo       .def_static("create", &PyOperation::create, py::arg("name"),
2184436c6c9cSStella Laurenzo                   py::arg("results") = py::none(),
2185436c6c9cSStella Laurenzo                   py::arg("operands") = py::none(),
2186436c6c9cSStella Laurenzo                   py::arg("attributes") = py::none(),
2187436c6c9cSStella Laurenzo                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2188436c6c9cSStella Laurenzo                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2189436c6c9cSStella Laurenzo                   kOperationCreateDocstring)
2190c65bb760SJohn Demme       .def_property_readonly("parent",
21911689dadeSJohn Demme                              [](PyOperation &self) -> py::object {
21921689dadeSJohn Demme                                auto parent = self.getParentOperation();
21931689dadeSJohn Demme                                if (parent)
21941689dadeSJohn Demme                                  return parent->getObject();
21951689dadeSJohn Demme                                return py::none();
2196c65bb760SJohn Demme                              })
219749745f87SMike Urbach       .def("erase", &PyOperation::erase)
21980126e906SJohn Demme       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
21990126e906SJohn Demme                              &PyOperation::getCapsule)
22000126e906SJohn Demme       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2201436c6c9cSStella Laurenzo       .def_property_readonly("name",
2202436c6c9cSStella Laurenzo                              [](PyOperation &self) {
220349745f87SMike Urbach                                self.checkValid();
2204436c6c9cSStella Laurenzo                                MlirOperation operation = self.get();
2205436c6c9cSStella Laurenzo                                MlirStringRef name = mlirIdentifierStr(
2206436c6c9cSStella Laurenzo                                    mlirOperationGetName(operation));
2207436c6c9cSStella Laurenzo                                return py::str(name.data, name.length);
2208436c6c9cSStella Laurenzo                              })
2209436c6c9cSStella Laurenzo       .def_property_readonly(
2210436c6c9cSStella Laurenzo           "context",
221149745f87SMike Urbach           [](PyOperation &self) {
221249745f87SMike Urbach             self.checkValid();
221349745f87SMike Urbach             return self.getContext().getObject();
221449745f87SMike Urbach           },
2215436c6c9cSStella Laurenzo           "Context that owns the Operation")
2216436c6c9cSStella Laurenzo       .def_property_readonly("opview", &PyOperation::createOpView);
2217436c6c9cSStella Laurenzo 
2218436c6c9cSStella Laurenzo   auto opViewClass =
2219f05ff4f7SStella Laurenzo       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2220436c6c9cSStella Laurenzo           .def(py::init<py::object>())
2221436c6c9cSStella Laurenzo           .def_property_readonly("operation", &PyOpView::getOperationObject)
2222436c6c9cSStella Laurenzo           .def_property_readonly(
2223436c6c9cSStella Laurenzo               "context",
2224436c6c9cSStella Laurenzo               [](PyOpView &self) {
2225436c6c9cSStella Laurenzo                 return self.getOperation().getContext().getObject();
2226436c6c9cSStella Laurenzo               },
2227436c6c9cSStella Laurenzo               "Context that owns the Operation")
2228436c6c9cSStella Laurenzo           .def("__str__", [](PyOpView &self) {
2229436c6c9cSStella Laurenzo             return py::str(self.getOperationObject());
2230436c6c9cSStella Laurenzo           });
2231436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2232436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2233436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2234436c6c9cSStella Laurenzo   opViewClass.attr("build_generic") = classmethod(
2235436c6c9cSStella Laurenzo       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2236436c6c9cSStella Laurenzo       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2237436c6c9cSStella Laurenzo       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2238436c6c9cSStella Laurenzo       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2239436c6c9cSStella Laurenzo       "Builds a specific, generated OpView based on class level attributes.");
2240436c6c9cSStella Laurenzo 
2241436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2242436c6c9cSStella Laurenzo   // Mapping of PyRegion.
2243436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2244f05ff4f7SStella Laurenzo   py::class_<PyRegion>(m, "Region", py::module_local())
2245436c6c9cSStella Laurenzo       .def_property_readonly(
2246436c6c9cSStella Laurenzo           "blocks",
2247436c6c9cSStella Laurenzo           [](PyRegion &self) {
2248436c6c9cSStella Laurenzo             return PyBlockList(self.getParentOperation(), self.get());
2249436c6c9cSStella Laurenzo           },
2250436c6c9cSStella Laurenzo           "Returns a forward-optimized sequence of blocks.")
2251436c6c9cSStella Laurenzo       .def(
2252436c6c9cSStella Laurenzo           "__iter__",
2253436c6c9cSStella Laurenzo           [](PyRegion &self) {
2254436c6c9cSStella Laurenzo             self.checkValid();
2255436c6c9cSStella Laurenzo             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2256436c6c9cSStella Laurenzo             return PyBlockIterator(self.getParentOperation(), firstBlock);
2257436c6c9cSStella Laurenzo           },
2258436c6c9cSStella Laurenzo           "Iterates over blocks in the region.")
2259436c6c9cSStella Laurenzo       .def("__eq__",
2260436c6c9cSStella Laurenzo            [](PyRegion &self, PyRegion &other) {
2261436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2262436c6c9cSStella Laurenzo            })
2263436c6c9cSStella Laurenzo       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2264436c6c9cSStella Laurenzo 
2265436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2266436c6c9cSStella Laurenzo   // Mapping of PyBlock.
2267436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2268f05ff4f7SStella Laurenzo   py::class_<PyBlock>(m, "Block", py::module_local())
2269436c6c9cSStella Laurenzo       .def_property_readonly(
227096fbd5cdSJohn Demme           "owner",
227196fbd5cdSJohn Demme           [](PyBlock &self) {
227296fbd5cdSJohn Demme             return self.getParentOperation()->createOpView();
227396fbd5cdSJohn Demme           },
227496fbd5cdSJohn Demme           "Returns the owning operation of this block.")
227596fbd5cdSJohn Demme       .def_property_readonly(
22768e6c55c9SStella Laurenzo           "region",
22778e6c55c9SStella Laurenzo           [](PyBlock &self) {
22788e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
22798e6c55c9SStella Laurenzo             return PyRegion(self.getParentOperation(), region);
22808e6c55c9SStella Laurenzo           },
22818e6c55c9SStella Laurenzo           "Returns the owning region of this block.")
22828e6c55c9SStella Laurenzo       .def_property_readonly(
2283436c6c9cSStella Laurenzo           "arguments",
2284436c6c9cSStella Laurenzo           [](PyBlock &self) {
2285436c6c9cSStella Laurenzo             return PyBlockArgumentList(self.getParentOperation(), self.get());
2286436c6c9cSStella Laurenzo           },
2287436c6c9cSStella Laurenzo           "Returns a list of block arguments.")
2288436c6c9cSStella Laurenzo       .def_property_readonly(
2289436c6c9cSStella Laurenzo           "operations",
2290436c6c9cSStella Laurenzo           [](PyBlock &self) {
2291436c6c9cSStella Laurenzo             return PyOperationList(self.getParentOperation(), self.get());
2292436c6c9cSStella Laurenzo           },
2293436c6c9cSStella Laurenzo           "Returns a forward-optimized sequence of operations.")
2294436c6c9cSStella Laurenzo       .def(
22958e6c55c9SStella Laurenzo           "create_before",
22968e6c55c9SStella Laurenzo           [](PyBlock &self, py::args pyArgTypes) {
22978e6c55c9SStella Laurenzo             self.checkValid();
22988e6c55c9SStella Laurenzo             llvm::SmallVector<MlirType, 4> argTypes;
22998e6c55c9SStella Laurenzo             argTypes.reserve(pyArgTypes.size());
23008e6c55c9SStella Laurenzo             for (auto &pyArg : pyArgTypes) {
23018e6c55c9SStella Laurenzo               argTypes.push_back(pyArg.cast<PyType &>());
23028e6c55c9SStella Laurenzo             }
23038e6c55c9SStella Laurenzo 
23048e6c55c9SStella Laurenzo             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
23058e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
23068e6c55c9SStella Laurenzo             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
23078e6c55c9SStella Laurenzo             return PyBlock(self.getParentOperation(), block);
23088e6c55c9SStella Laurenzo           },
23098e6c55c9SStella Laurenzo           "Creates and returns a new Block before this block "
23108e6c55c9SStella Laurenzo           "(with given argument types).")
23118e6c55c9SStella Laurenzo       .def(
23128e6c55c9SStella Laurenzo           "create_after",
23138e6c55c9SStella Laurenzo           [](PyBlock &self, py::args pyArgTypes) {
23148e6c55c9SStella Laurenzo             self.checkValid();
23158e6c55c9SStella Laurenzo             llvm::SmallVector<MlirType, 4> argTypes;
23168e6c55c9SStella Laurenzo             argTypes.reserve(pyArgTypes.size());
23178e6c55c9SStella Laurenzo             for (auto &pyArg : pyArgTypes) {
23188e6c55c9SStella Laurenzo               argTypes.push_back(pyArg.cast<PyType &>());
23198e6c55c9SStella Laurenzo             }
23208e6c55c9SStella Laurenzo 
23218e6c55c9SStella Laurenzo             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
23228e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
23238e6c55c9SStella Laurenzo             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
23248e6c55c9SStella Laurenzo             return PyBlock(self.getParentOperation(), block);
23258e6c55c9SStella Laurenzo           },
23268e6c55c9SStella Laurenzo           "Creates and returns a new Block after this block "
23278e6c55c9SStella Laurenzo           "(with given argument types).")
23288e6c55c9SStella Laurenzo       .def(
2329436c6c9cSStella Laurenzo           "__iter__",
2330436c6c9cSStella Laurenzo           [](PyBlock &self) {
2331436c6c9cSStella Laurenzo             self.checkValid();
2332436c6c9cSStella Laurenzo             MlirOperation firstOperation =
2333436c6c9cSStella Laurenzo                 mlirBlockGetFirstOperation(self.get());
2334436c6c9cSStella Laurenzo             return PyOperationIterator(self.getParentOperation(),
2335436c6c9cSStella Laurenzo                                        firstOperation);
2336436c6c9cSStella Laurenzo           },
2337436c6c9cSStella Laurenzo           "Iterates over operations in the block.")
2338436c6c9cSStella Laurenzo       .def("__eq__",
2339436c6c9cSStella Laurenzo            [](PyBlock &self, PyBlock &other) {
2340436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2341436c6c9cSStella Laurenzo            })
2342436c6c9cSStella Laurenzo       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2343436c6c9cSStella Laurenzo       .def(
2344436c6c9cSStella Laurenzo           "__str__",
2345436c6c9cSStella Laurenzo           [](PyBlock &self) {
2346436c6c9cSStella Laurenzo             self.checkValid();
2347436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2348436c6c9cSStella Laurenzo             mlirBlockPrint(self.get(), printAccum.getCallback(),
2349436c6c9cSStella Laurenzo                            printAccum.getUserData());
2350436c6c9cSStella Laurenzo             return printAccum.join();
2351436c6c9cSStella Laurenzo           },
2352436c6c9cSStella Laurenzo           "Returns the assembly form of the block.");
2353436c6c9cSStella Laurenzo 
2354436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2355436c6c9cSStella Laurenzo   // Mapping of PyInsertionPoint.
2356436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2357436c6c9cSStella Laurenzo 
2358f05ff4f7SStella Laurenzo   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2359436c6c9cSStella Laurenzo       .def(py::init<PyBlock &>(), py::arg("block"),
2360436c6c9cSStella Laurenzo            "Inserts after the last operation but still inside the block.")
2361436c6c9cSStella Laurenzo       .def("__enter__", &PyInsertionPoint::contextEnter)
2362436c6c9cSStella Laurenzo       .def("__exit__", &PyInsertionPoint::contextExit)
2363436c6c9cSStella Laurenzo       .def_property_readonly_static(
2364436c6c9cSStella Laurenzo           "current",
2365436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
2366436c6c9cSStella Laurenzo             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2367436c6c9cSStella Laurenzo             if (!ip)
2368436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2369436c6c9cSStella Laurenzo             return ip;
2370436c6c9cSStella Laurenzo           },
2371436c6c9cSStella Laurenzo           "Gets the InsertionPoint bound to the current thread or raises "
2372436c6c9cSStella Laurenzo           "ValueError if none has been set")
2373436c6c9cSStella Laurenzo       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2374436c6c9cSStella Laurenzo            "Inserts before a referenced operation.")
2375436c6c9cSStella Laurenzo       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2376436c6c9cSStella Laurenzo                   py::arg("block"), "Inserts at the beginning of the block.")
2377436c6c9cSStella Laurenzo       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2378436c6c9cSStella Laurenzo                   py::arg("block"), "Inserts before the block terminator.")
2379436c6c9cSStella Laurenzo       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
23808e6c55c9SStella Laurenzo            "Inserts an operation.")
23818e6c55c9SStella Laurenzo       .def_property_readonly(
23828e6c55c9SStella Laurenzo           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
23838e6c55c9SStella Laurenzo           "Returns the block that this InsertionPoint points to.");
2384436c6c9cSStella Laurenzo 
2385436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2386436c6c9cSStella Laurenzo   // Mapping of PyAttribute.
2387436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2388f05ff4f7SStella Laurenzo   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2389b57d6fe4SStella Laurenzo       // Delegate to the PyAttribute copy constructor, which will also lifetime
2390b57d6fe4SStella Laurenzo       // extend the backing context which owns the MlirAttribute.
2391b57d6fe4SStella Laurenzo       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2392b57d6fe4SStella Laurenzo            "Casts the passed attribute to the generic Attribute")
2393436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2394436c6c9cSStella Laurenzo                              &PyAttribute::getCapsule)
2395436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2396436c6c9cSStella Laurenzo       .def_static(
2397436c6c9cSStella Laurenzo           "parse",
2398436c6c9cSStella Laurenzo           [](std::string attrSpec, DefaultingPyMlirContext context) {
2399436c6c9cSStella Laurenzo             MlirAttribute type = mlirAttributeParseGet(
2400436c6c9cSStella Laurenzo                 context->get(), toMlirStringRef(attrSpec));
2401436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
2402436c6c9cSStella Laurenzo             // in C API.
2403436c6c9cSStella Laurenzo             if (mlirAttributeIsNull(type)) {
2404436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
2405436c6c9cSStella Laurenzo                                Twine("Unable to parse attribute: '") +
2406436c6c9cSStella Laurenzo                                    attrSpec + "'");
2407436c6c9cSStella Laurenzo             }
2408436c6c9cSStella Laurenzo             return PyAttribute(context->getRef(), type);
2409436c6c9cSStella Laurenzo           },
2410436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
2411436c6c9cSStella Laurenzo           "Parses an attribute from an assembly form")
2412436c6c9cSStella Laurenzo       .def_property_readonly(
2413436c6c9cSStella Laurenzo           "context",
2414436c6c9cSStella Laurenzo           [](PyAttribute &self) { return self.getContext().getObject(); },
2415436c6c9cSStella Laurenzo           "Context that owns the Attribute")
2416436c6c9cSStella Laurenzo       .def_property_readonly("type",
2417436c6c9cSStella Laurenzo                              [](PyAttribute &self) {
2418436c6c9cSStella Laurenzo                                return PyType(self.getContext()->getRef(),
2419436c6c9cSStella Laurenzo                                              mlirAttributeGetType(self));
2420436c6c9cSStella Laurenzo                              })
2421436c6c9cSStella Laurenzo       .def(
2422436c6c9cSStella Laurenzo           "get_named",
2423436c6c9cSStella Laurenzo           [](PyAttribute &self, std::string name) {
2424436c6c9cSStella Laurenzo             return PyNamedAttribute(self, std::move(name));
2425436c6c9cSStella Laurenzo           },
2426436c6c9cSStella Laurenzo           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2427436c6c9cSStella Laurenzo       .def("__eq__",
2428436c6c9cSStella Laurenzo            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2429436c6c9cSStella Laurenzo       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
243047cc166bSJohn Demme       .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; })
2431436c6c9cSStella Laurenzo       .def(
2432436c6c9cSStella Laurenzo           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2433436c6c9cSStella Laurenzo           kDumpDocstring)
2434436c6c9cSStella Laurenzo       .def(
2435436c6c9cSStella Laurenzo           "__str__",
2436436c6c9cSStella Laurenzo           [](PyAttribute &self) {
2437436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2438436c6c9cSStella Laurenzo             mlirAttributePrint(self, printAccum.getCallback(),
2439436c6c9cSStella Laurenzo                                printAccum.getUserData());
2440436c6c9cSStella Laurenzo             return printAccum.join();
2441436c6c9cSStella Laurenzo           },
2442436c6c9cSStella Laurenzo           "Returns the assembly form of the Attribute.")
2443436c6c9cSStella Laurenzo       .def("__repr__", [](PyAttribute &self) {
2444436c6c9cSStella Laurenzo         // Generally, assembly formats are not printed for __repr__ because
2445436c6c9cSStella Laurenzo         // this can cause exceptionally long debug output and exceptions.
2446436c6c9cSStella Laurenzo         // However, attribute values are generally considered useful and are
2447436c6c9cSStella Laurenzo         // printed. This may need to be re-evaluated if debug dumps end up
2448436c6c9cSStella Laurenzo         // being excessive.
2449436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
2450436c6c9cSStella Laurenzo         printAccum.parts.append("Attribute(");
2451436c6c9cSStella Laurenzo         mlirAttributePrint(self, printAccum.getCallback(),
2452436c6c9cSStella Laurenzo                            printAccum.getUserData());
2453436c6c9cSStella Laurenzo         printAccum.parts.append(")");
2454436c6c9cSStella Laurenzo         return printAccum.join();
2455436c6c9cSStella Laurenzo       });
2456436c6c9cSStella Laurenzo 
2457436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2458436c6c9cSStella Laurenzo   // Mapping of PyNamedAttribute
2459436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2460f05ff4f7SStella Laurenzo   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2461436c6c9cSStella Laurenzo       .def("__repr__",
2462436c6c9cSStella Laurenzo            [](PyNamedAttribute &self) {
2463436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
2464436c6c9cSStella Laurenzo              printAccum.parts.append("NamedAttribute(");
2465436c6c9cSStella Laurenzo              printAccum.parts.append(
2466436c6c9cSStella Laurenzo                  mlirIdentifierStr(self.namedAttr.name).data);
2467436c6c9cSStella Laurenzo              printAccum.parts.append("=");
2468436c6c9cSStella Laurenzo              mlirAttributePrint(self.namedAttr.attribute,
2469436c6c9cSStella Laurenzo                                 printAccum.getCallback(),
2470436c6c9cSStella Laurenzo                                 printAccum.getUserData());
2471436c6c9cSStella Laurenzo              printAccum.parts.append(")");
2472436c6c9cSStella Laurenzo              return printAccum.join();
2473436c6c9cSStella Laurenzo            })
2474436c6c9cSStella Laurenzo       .def_property_readonly(
2475436c6c9cSStella Laurenzo           "name",
2476436c6c9cSStella Laurenzo           [](PyNamedAttribute &self) {
2477436c6c9cSStella Laurenzo             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2478436c6c9cSStella Laurenzo                            mlirIdentifierStr(self.namedAttr.name).length);
2479436c6c9cSStella Laurenzo           },
2480436c6c9cSStella Laurenzo           "The name of the NamedAttribute binding")
2481436c6c9cSStella Laurenzo       .def_property_readonly(
2482436c6c9cSStella Laurenzo           "attr",
2483436c6c9cSStella Laurenzo           [](PyNamedAttribute &self) {
2484436c6c9cSStella Laurenzo             // TODO: When named attribute is removed/refactored, also remove
2485436c6c9cSStella Laurenzo             // this constructor (it does an inefficient table lookup).
2486436c6c9cSStella Laurenzo             auto contextRef = PyMlirContext::forContext(
2487436c6c9cSStella Laurenzo                 mlirAttributeGetContext(self.namedAttr.attribute));
2488436c6c9cSStella Laurenzo             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2489436c6c9cSStella Laurenzo           },
2490436c6c9cSStella Laurenzo           py::keep_alive<0, 1>(),
2491436c6c9cSStella Laurenzo           "The underlying generic attribute of the NamedAttribute binding");
2492436c6c9cSStella Laurenzo 
2493436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2494436c6c9cSStella Laurenzo   // Mapping of PyType.
2495436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2496f05ff4f7SStella Laurenzo   py::class_<PyType>(m, "Type", py::module_local())
2497b57d6fe4SStella Laurenzo       // Delegate to the PyType copy constructor, which will also lifetime
2498b57d6fe4SStella Laurenzo       // extend the backing context which owns the MlirType.
2499b57d6fe4SStella Laurenzo       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2500b57d6fe4SStella Laurenzo            "Casts the passed type to the generic Type")
2501436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2502436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2503436c6c9cSStella Laurenzo       .def_static(
2504436c6c9cSStella Laurenzo           "parse",
2505436c6c9cSStella Laurenzo           [](std::string typeSpec, DefaultingPyMlirContext context) {
2506436c6c9cSStella Laurenzo             MlirType type =
2507436c6c9cSStella Laurenzo                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2508436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
2509436c6c9cSStella Laurenzo             // in C API.
2510436c6c9cSStella Laurenzo             if (mlirTypeIsNull(type)) {
2511436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
2512436c6c9cSStella Laurenzo                                Twine("Unable to parse type: '") + typeSpec +
2513436c6c9cSStella Laurenzo                                    "'");
2514436c6c9cSStella Laurenzo             }
2515436c6c9cSStella Laurenzo             return PyType(context->getRef(), type);
2516436c6c9cSStella Laurenzo           },
2517436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
2518436c6c9cSStella Laurenzo           kContextParseTypeDocstring)
2519436c6c9cSStella Laurenzo       .def_property_readonly(
2520436c6c9cSStella Laurenzo           "context", [](PyType &self) { return self.getContext().getObject(); },
2521436c6c9cSStella Laurenzo           "Context that owns the Type")
2522436c6c9cSStella Laurenzo       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2523436c6c9cSStella Laurenzo       .def("__eq__", [](PyType &self, py::object &other) { return false; })
252447cc166bSJohn Demme       .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; })
2525436c6c9cSStella Laurenzo       .def(
2526436c6c9cSStella Laurenzo           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2527436c6c9cSStella Laurenzo       .def(
2528436c6c9cSStella Laurenzo           "__str__",
2529436c6c9cSStella Laurenzo           [](PyType &self) {
2530436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2531436c6c9cSStella Laurenzo             mlirTypePrint(self, printAccum.getCallback(),
2532436c6c9cSStella Laurenzo                           printAccum.getUserData());
2533436c6c9cSStella Laurenzo             return printAccum.join();
2534436c6c9cSStella Laurenzo           },
2535436c6c9cSStella Laurenzo           "Returns the assembly form of the type.")
2536436c6c9cSStella Laurenzo       .def("__repr__", [](PyType &self) {
2537436c6c9cSStella Laurenzo         // Generally, assembly formats are not printed for __repr__ because
2538436c6c9cSStella Laurenzo         // this can cause exceptionally long debug output and exceptions.
2539436c6c9cSStella Laurenzo         // However, types are an exception as they typically have compact
2540436c6c9cSStella Laurenzo         // assembly forms and printing them is useful.
2541436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
2542436c6c9cSStella Laurenzo         printAccum.parts.append("Type(");
2543436c6c9cSStella Laurenzo         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2544436c6c9cSStella Laurenzo         printAccum.parts.append(")");
2545436c6c9cSStella Laurenzo         return printAccum.join();
2546436c6c9cSStella Laurenzo       });
2547436c6c9cSStella Laurenzo 
2548436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2549436c6c9cSStella Laurenzo   // Mapping of Value.
2550436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2551f05ff4f7SStella Laurenzo   py::class_<PyValue>(m, "Value", py::module_local())
25523f3d1c90SMike Urbach       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
25533f3d1c90SMike Urbach       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2554436c6c9cSStella Laurenzo       .def_property_readonly(
2555436c6c9cSStella Laurenzo           "context",
2556436c6c9cSStella Laurenzo           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2557436c6c9cSStella Laurenzo           "Context in which the value lives.")
2558436c6c9cSStella Laurenzo       .def(
2559436c6c9cSStella Laurenzo           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2560436c6c9cSStella Laurenzo           kDumpDocstring)
25615664c5e2SJohn Demme       .def_property_readonly(
25625664c5e2SJohn Demme           "owner",
25635664c5e2SJohn Demme           [](PyValue &self) {
25645664c5e2SJohn Demme             assert(mlirOperationEqual(self.getParentOperation()->get(),
25655664c5e2SJohn Demme                                       mlirOpResultGetOwner(self.get())) &&
25665664c5e2SJohn Demme                    "expected the owner of the value in Python to match that in "
25675664c5e2SJohn Demme                    "the IR");
25685664c5e2SJohn Demme             return self.getParentOperation().getObject();
25695664c5e2SJohn Demme           })
2570436c6c9cSStella Laurenzo       .def("__eq__",
2571436c6c9cSStella Laurenzo            [](PyValue &self, PyValue &other) {
2572436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2573436c6c9cSStella Laurenzo            })
2574436c6c9cSStella Laurenzo       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2575436c6c9cSStella Laurenzo       .def(
2576436c6c9cSStella Laurenzo           "__str__",
2577436c6c9cSStella Laurenzo           [](PyValue &self) {
2578436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2579436c6c9cSStella Laurenzo             printAccum.parts.append("Value(");
2580436c6c9cSStella Laurenzo             mlirValuePrint(self.get(), printAccum.getCallback(),
2581436c6c9cSStella Laurenzo                            printAccum.getUserData());
2582436c6c9cSStella Laurenzo             printAccum.parts.append(")");
2583436c6c9cSStella Laurenzo             return printAccum.join();
2584436c6c9cSStella Laurenzo           },
2585436c6c9cSStella Laurenzo           kValueDunderStrDocstring)
2586436c6c9cSStella Laurenzo       .def_property_readonly("type", [](PyValue &self) {
2587436c6c9cSStella Laurenzo         return PyType(self.getParentOperation()->getContext(),
2588436c6c9cSStella Laurenzo                       mlirValueGetType(self.get()));
2589436c6c9cSStella Laurenzo       });
  // Concrete Value subclasses. They are registered after `Value` (above).
  // Presumably their pybind11 classes name Value as the base, which must
  // already be registered, so keep this ordering.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings. Each helper registers its own Python class on `m`.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings (presumably wrapping the mlir-c/Debug.h API included at the
  // top of this file; confirm in PyGlobalDebugFlag).
  PyGlobalDebugFlag::bind(m);
2607436c6c9cSStella Laurenzo }
2608