1436c6c9cSStella Laurenzo //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9436c6c9cSStella Laurenzo #include "IRModule.h"
10436c6c9cSStella Laurenzo 
11436c6c9cSStella Laurenzo #include "Globals.h"
12436c6c9cSStella Laurenzo #include "PybindUtils.h"
13436c6c9cSStella Laurenzo 
14436c6c9cSStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h"
15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
17436c6c9cSStella Laurenzo #include "mlir-c/Registration.h"
18436c6c9cSStella Laurenzo #include "llvm/ADT/SmallVector.h"
19436c6c9cSStella Laurenzo #include <pybind11/stl.h>
20436c6c9cSStella Laurenzo 
21436c6c9cSStella Laurenzo namespace py = pybind11;
22436c6c9cSStella Laurenzo using namespace mlir;
23436c6c9cSStella Laurenzo using namespace mlir::python;
24436c6c9cSStella Laurenzo 
25436c6c9cSStella Laurenzo using llvm::SmallVector;
26436c6c9cSStella Laurenzo using llvm::StringRef;
27436c6c9cSStella Laurenzo using llvm::Twine;
28436c6c9cSStella Laurenzo 
29436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
30436c6c9cSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
31436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
32436c6c9cSStella Laurenzo 
// Python docstring for parsing a type from its assembly text. Per its text, the
// binding returns a Type or raises ValueError on parse failure.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Python docstring for building a file/line/column Location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Python docstring for parsing a module from its assembly text.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Python docstring for the generic operation-creation entry point.
// NOTE(review): the binding site is not visible here; verify that the Args
// listed below match the py::arg names used where this docstring is attached.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
69436c6c9cSStella Laurenzo 
// Python docstring for the operation print() binding. The keyword names listed
// under Args must match the py::arg names used where print() is bound, since
// users pass these options by keyword.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file-like object.

Args:
  file: The file-like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: If set, elides elements attributes that have more
    than this number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";
88436c6c9cSStella Laurenzo 
// Python docstring for get_asm(): like print(), but returns the text (or bytes)
// instead of writing it to a file.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Python docstring for the operation's __str__ (default printing options).
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared Python docstring for dump() methods.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Python docstring for BlockList.append (attached in PyBlockList::bind).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Python docstring for a Value's __str__.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
127436c6c9cSStella Laurenzo 
128436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
129436c6c9cSStella Laurenzo // Utilities.
130436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
131436c6c9cSStella Laurenzo 
132436c6c9cSStella Laurenzo // Helper for creating an @classmethod.
133436c6c9cSStella Laurenzo template <class Func, typename... Args>
134436c6c9cSStella Laurenzo py::object classmethod(Func f, Args... args) {
135436c6c9cSStella Laurenzo   py::object cf = py::cpp_function(f, args...);
136436c6c9cSStella Laurenzo   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
137436c6c9cSStella Laurenzo }
138436c6c9cSStella Laurenzo 
139436c6c9cSStella Laurenzo static py::object
140436c6c9cSStella Laurenzo createCustomDialectWrapper(const std::string &dialectNamespace,
141436c6c9cSStella Laurenzo                            py::object dialectDescriptor) {
142436c6c9cSStella Laurenzo   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
143436c6c9cSStella Laurenzo   if (!dialectClass) {
144436c6c9cSStella Laurenzo     // Use the base class.
145436c6c9cSStella Laurenzo     return py::cast(PyDialect(std::move(dialectDescriptor)));
146436c6c9cSStella Laurenzo   }
147436c6c9cSStella Laurenzo 
148436c6c9cSStella Laurenzo   // Create the custom implementation.
149436c6c9cSStella Laurenzo   return (*dialectClass)(std::move(dialectDescriptor));
150436c6c9cSStella Laurenzo }
151436c6c9cSStella Laurenzo 
152436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
153436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
154436c6c9cSStella Laurenzo }
155436c6c9cSStella Laurenzo 
156436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
157436c6c9cSStella Laurenzo // Collections.
158436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
159436c6c9cSStella Laurenzo 
160436c6c9cSStella Laurenzo namespace {
161436c6c9cSStella Laurenzo 
162436c6c9cSStella Laurenzo class PyRegionIterator {
163436c6c9cSStella Laurenzo public:
164436c6c9cSStella Laurenzo   PyRegionIterator(PyOperationRef operation)
165436c6c9cSStella Laurenzo       : operation(std::move(operation)) {}
166436c6c9cSStella Laurenzo 
167436c6c9cSStella Laurenzo   PyRegionIterator &dunderIter() { return *this; }
168436c6c9cSStella Laurenzo 
169436c6c9cSStella Laurenzo   PyRegion dunderNext() {
170436c6c9cSStella Laurenzo     operation->checkValid();
171436c6c9cSStella Laurenzo     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
172436c6c9cSStella Laurenzo       throw py::stop_iteration();
173436c6c9cSStella Laurenzo     }
174436c6c9cSStella Laurenzo     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
175436c6c9cSStella Laurenzo     return PyRegion(operation, region);
176436c6c9cSStella Laurenzo   }
177436c6c9cSStella Laurenzo 
178436c6c9cSStella Laurenzo   static void bind(py::module &m) {
179436c6c9cSStella Laurenzo     py::class_<PyRegionIterator>(m, "RegionIterator")
180436c6c9cSStella Laurenzo         .def("__iter__", &PyRegionIterator::dunderIter)
181436c6c9cSStella Laurenzo         .def("__next__", &PyRegionIterator::dunderNext);
182436c6c9cSStella Laurenzo   }
183436c6c9cSStella Laurenzo 
184436c6c9cSStella Laurenzo private:
185436c6c9cSStella Laurenzo   PyOperationRef operation;
186436c6c9cSStella Laurenzo   int nextIndex = 0;
187436c6c9cSStella Laurenzo };
188436c6c9cSStella Laurenzo 
189436c6c9cSStella Laurenzo /// Regions of an op are fixed length and indexed numerically so are represented
190436c6c9cSStella Laurenzo /// with a sequence-like container.
191436c6c9cSStella Laurenzo class PyRegionList {
192436c6c9cSStella Laurenzo public:
193436c6c9cSStella Laurenzo   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
194436c6c9cSStella Laurenzo 
195436c6c9cSStella Laurenzo   intptr_t dunderLen() {
196436c6c9cSStella Laurenzo     operation->checkValid();
197436c6c9cSStella Laurenzo     return mlirOperationGetNumRegions(operation->get());
198436c6c9cSStella Laurenzo   }
199436c6c9cSStella Laurenzo 
200436c6c9cSStella Laurenzo   PyRegion dunderGetItem(intptr_t index) {
201436c6c9cSStella Laurenzo     // dunderLen checks validity.
202436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
203436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
204436c6c9cSStella Laurenzo                        "attempt to access out of bounds region");
205436c6c9cSStella Laurenzo     }
206436c6c9cSStella Laurenzo     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
207436c6c9cSStella Laurenzo     return PyRegion(operation, region);
208436c6c9cSStella Laurenzo   }
209436c6c9cSStella Laurenzo 
210436c6c9cSStella Laurenzo   static void bind(py::module &m) {
211436c6c9cSStella Laurenzo     py::class_<PyRegionList>(m, "RegionSequence")
212436c6c9cSStella Laurenzo         .def("__len__", &PyRegionList::dunderLen)
213436c6c9cSStella Laurenzo         .def("__getitem__", &PyRegionList::dunderGetItem);
214436c6c9cSStella Laurenzo   }
215436c6c9cSStella Laurenzo 
216436c6c9cSStella Laurenzo private:
217436c6c9cSStella Laurenzo   PyOperationRef operation;
218436c6c9cSStella Laurenzo };
219436c6c9cSStella Laurenzo 
220436c6c9cSStella Laurenzo class PyBlockIterator {
221436c6c9cSStella Laurenzo public:
222436c6c9cSStella Laurenzo   PyBlockIterator(PyOperationRef operation, MlirBlock next)
223436c6c9cSStella Laurenzo       : operation(std::move(operation)), next(next) {}
224436c6c9cSStella Laurenzo 
225436c6c9cSStella Laurenzo   PyBlockIterator &dunderIter() { return *this; }
226436c6c9cSStella Laurenzo 
227436c6c9cSStella Laurenzo   PyBlock dunderNext() {
228436c6c9cSStella Laurenzo     operation->checkValid();
229436c6c9cSStella Laurenzo     if (mlirBlockIsNull(next)) {
230436c6c9cSStella Laurenzo       throw py::stop_iteration();
231436c6c9cSStella Laurenzo     }
232436c6c9cSStella Laurenzo 
233436c6c9cSStella Laurenzo     PyBlock returnBlock(operation, next);
234436c6c9cSStella Laurenzo     next = mlirBlockGetNextInRegion(next);
235436c6c9cSStella Laurenzo     return returnBlock;
236436c6c9cSStella Laurenzo   }
237436c6c9cSStella Laurenzo 
238436c6c9cSStella Laurenzo   static void bind(py::module &m) {
239436c6c9cSStella Laurenzo     py::class_<PyBlockIterator>(m, "BlockIterator")
240436c6c9cSStella Laurenzo         .def("__iter__", &PyBlockIterator::dunderIter)
241436c6c9cSStella Laurenzo         .def("__next__", &PyBlockIterator::dunderNext);
242436c6c9cSStella Laurenzo   }
243436c6c9cSStella Laurenzo 
244436c6c9cSStella Laurenzo private:
245436c6c9cSStella Laurenzo   PyOperationRef operation;
246436c6c9cSStella Laurenzo   MlirBlock next;
247436c6c9cSStella Laurenzo };
248436c6c9cSStella Laurenzo 
249436c6c9cSStella Laurenzo /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
250436c6c9cSStella Laurenzo /// we present them as a more full-featured list-like container but optimize
251436c6c9cSStella Laurenzo /// it for forward iteration. Blocks are always owned by a region.
252436c6c9cSStella Laurenzo class PyBlockList {
253436c6c9cSStella Laurenzo public:
254436c6c9cSStella Laurenzo   PyBlockList(PyOperationRef operation, MlirRegion region)
255436c6c9cSStella Laurenzo       : operation(std::move(operation)), region(region) {}
256436c6c9cSStella Laurenzo 
257436c6c9cSStella Laurenzo   PyBlockIterator dunderIter() {
258436c6c9cSStella Laurenzo     operation->checkValid();
259436c6c9cSStella Laurenzo     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
260436c6c9cSStella Laurenzo   }
261436c6c9cSStella Laurenzo 
262436c6c9cSStella Laurenzo   intptr_t dunderLen() {
263436c6c9cSStella Laurenzo     operation->checkValid();
264436c6c9cSStella Laurenzo     intptr_t count = 0;
265436c6c9cSStella Laurenzo     MlirBlock block = mlirRegionGetFirstBlock(region);
266436c6c9cSStella Laurenzo     while (!mlirBlockIsNull(block)) {
267436c6c9cSStella Laurenzo       count += 1;
268436c6c9cSStella Laurenzo       block = mlirBlockGetNextInRegion(block);
269436c6c9cSStella Laurenzo     }
270436c6c9cSStella Laurenzo     return count;
271436c6c9cSStella Laurenzo   }
272436c6c9cSStella Laurenzo 
273436c6c9cSStella Laurenzo   PyBlock dunderGetItem(intptr_t index) {
274436c6c9cSStella Laurenzo     operation->checkValid();
275436c6c9cSStella Laurenzo     if (index < 0) {
276436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
277436c6c9cSStella Laurenzo                        "attempt to access out of bounds block");
278436c6c9cSStella Laurenzo     }
279436c6c9cSStella Laurenzo     MlirBlock block = mlirRegionGetFirstBlock(region);
280436c6c9cSStella Laurenzo     while (!mlirBlockIsNull(block)) {
281436c6c9cSStella Laurenzo       if (index == 0) {
282436c6c9cSStella Laurenzo         return PyBlock(operation, block);
283436c6c9cSStella Laurenzo       }
284436c6c9cSStella Laurenzo       block = mlirBlockGetNextInRegion(block);
285436c6c9cSStella Laurenzo       index -= 1;
286436c6c9cSStella Laurenzo     }
287436c6c9cSStella Laurenzo     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
288436c6c9cSStella Laurenzo   }
289436c6c9cSStella Laurenzo 
290436c6c9cSStella Laurenzo   PyBlock appendBlock(py::args pyArgTypes) {
291436c6c9cSStella Laurenzo     operation->checkValid();
292436c6c9cSStella Laurenzo     llvm::SmallVector<MlirType, 4> argTypes;
293436c6c9cSStella Laurenzo     argTypes.reserve(pyArgTypes.size());
294436c6c9cSStella Laurenzo     for (auto &pyArg : pyArgTypes) {
295436c6c9cSStella Laurenzo       argTypes.push_back(pyArg.cast<PyType &>());
296436c6c9cSStella Laurenzo     }
297436c6c9cSStella Laurenzo 
298436c6c9cSStella Laurenzo     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
299436c6c9cSStella Laurenzo     mlirRegionAppendOwnedBlock(region, block);
300436c6c9cSStella Laurenzo     return PyBlock(operation, block);
301436c6c9cSStella Laurenzo   }
302436c6c9cSStella Laurenzo 
303436c6c9cSStella Laurenzo   static void bind(py::module &m) {
304436c6c9cSStella Laurenzo     py::class_<PyBlockList>(m, "BlockList")
305436c6c9cSStella Laurenzo         .def("__getitem__", &PyBlockList::dunderGetItem)
306436c6c9cSStella Laurenzo         .def("__iter__", &PyBlockList::dunderIter)
307436c6c9cSStella Laurenzo         .def("__len__", &PyBlockList::dunderLen)
308436c6c9cSStella Laurenzo         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
309436c6c9cSStella Laurenzo   }
310436c6c9cSStella Laurenzo 
311436c6c9cSStella Laurenzo private:
312436c6c9cSStella Laurenzo   PyOperationRef operation;
313436c6c9cSStella Laurenzo   MlirRegion region;
314436c6c9cSStella Laurenzo };
315436c6c9cSStella Laurenzo 
316436c6c9cSStella Laurenzo class PyOperationIterator {
317436c6c9cSStella Laurenzo public:
318436c6c9cSStella Laurenzo   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
319436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), next(next) {}
320436c6c9cSStella Laurenzo 
321436c6c9cSStella Laurenzo   PyOperationIterator &dunderIter() { return *this; }
322436c6c9cSStella Laurenzo 
323436c6c9cSStella Laurenzo   py::object dunderNext() {
324436c6c9cSStella Laurenzo     parentOperation->checkValid();
325436c6c9cSStella Laurenzo     if (mlirOperationIsNull(next)) {
326436c6c9cSStella Laurenzo       throw py::stop_iteration();
327436c6c9cSStella Laurenzo     }
328436c6c9cSStella Laurenzo 
329436c6c9cSStella Laurenzo     PyOperationRef returnOperation =
330436c6c9cSStella Laurenzo         PyOperation::forOperation(parentOperation->getContext(), next);
331436c6c9cSStella Laurenzo     next = mlirOperationGetNextInBlock(next);
332436c6c9cSStella Laurenzo     return returnOperation->createOpView();
333436c6c9cSStella Laurenzo   }
334436c6c9cSStella Laurenzo 
335436c6c9cSStella Laurenzo   static void bind(py::module &m) {
336436c6c9cSStella Laurenzo     py::class_<PyOperationIterator>(m, "OperationIterator")
337436c6c9cSStella Laurenzo         .def("__iter__", &PyOperationIterator::dunderIter)
338436c6c9cSStella Laurenzo         .def("__next__", &PyOperationIterator::dunderNext);
339436c6c9cSStella Laurenzo   }
340436c6c9cSStella Laurenzo 
341436c6c9cSStella Laurenzo private:
342436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
343436c6c9cSStella Laurenzo   MlirOperation next;
344436c6c9cSStella Laurenzo };
345436c6c9cSStella Laurenzo 
346436c6c9cSStella Laurenzo /// Operations are exposed by the C-API as a forward-only linked list. In
347436c6c9cSStella Laurenzo /// Python, we present them as a more full-featured list-like container but
348436c6c9cSStella Laurenzo /// optimize it for forward iteration. Iterable operations are always owned
349436c6c9cSStella Laurenzo /// by a block.
350436c6c9cSStella Laurenzo class PyOperationList {
351436c6c9cSStella Laurenzo public:
352436c6c9cSStella Laurenzo   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
353436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), block(block) {}
354436c6c9cSStella Laurenzo 
355436c6c9cSStella Laurenzo   PyOperationIterator dunderIter() {
356436c6c9cSStella Laurenzo     parentOperation->checkValid();
357436c6c9cSStella Laurenzo     return PyOperationIterator(parentOperation,
358436c6c9cSStella Laurenzo                                mlirBlockGetFirstOperation(block));
359436c6c9cSStella Laurenzo   }
360436c6c9cSStella Laurenzo 
361436c6c9cSStella Laurenzo   intptr_t dunderLen() {
362436c6c9cSStella Laurenzo     parentOperation->checkValid();
363436c6c9cSStella Laurenzo     intptr_t count = 0;
364436c6c9cSStella Laurenzo     MlirOperation childOp = mlirBlockGetFirstOperation(block);
365436c6c9cSStella Laurenzo     while (!mlirOperationIsNull(childOp)) {
366436c6c9cSStella Laurenzo       count += 1;
367436c6c9cSStella Laurenzo       childOp = mlirOperationGetNextInBlock(childOp);
368436c6c9cSStella Laurenzo     }
369436c6c9cSStella Laurenzo     return count;
370436c6c9cSStella Laurenzo   }
371436c6c9cSStella Laurenzo 
372436c6c9cSStella Laurenzo   py::object dunderGetItem(intptr_t index) {
373436c6c9cSStella Laurenzo     parentOperation->checkValid();
374436c6c9cSStella Laurenzo     if (index < 0) {
375436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
376436c6c9cSStella Laurenzo                        "attempt to access out of bounds operation");
377436c6c9cSStella Laurenzo     }
378436c6c9cSStella Laurenzo     MlirOperation childOp = mlirBlockGetFirstOperation(block);
379436c6c9cSStella Laurenzo     while (!mlirOperationIsNull(childOp)) {
380436c6c9cSStella Laurenzo       if (index == 0) {
381436c6c9cSStella Laurenzo         return PyOperation::forOperation(parentOperation->getContext(), childOp)
382436c6c9cSStella Laurenzo             ->createOpView();
383436c6c9cSStella Laurenzo       }
384436c6c9cSStella Laurenzo       childOp = mlirOperationGetNextInBlock(childOp);
385436c6c9cSStella Laurenzo       index -= 1;
386436c6c9cSStella Laurenzo     }
387436c6c9cSStella Laurenzo     throw SetPyError(PyExc_IndexError,
388436c6c9cSStella Laurenzo                      "attempt to access out of bounds operation");
389436c6c9cSStella Laurenzo   }
390436c6c9cSStella Laurenzo 
391436c6c9cSStella Laurenzo   static void bind(py::module &m) {
392436c6c9cSStella Laurenzo     py::class_<PyOperationList>(m, "OperationList")
393436c6c9cSStella Laurenzo         .def("__getitem__", &PyOperationList::dunderGetItem)
394436c6c9cSStella Laurenzo         .def("__iter__", &PyOperationList::dunderIter)
395436c6c9cSStella Laurenzo         .def("__len__", &PyOperationList::dunderLen);
396436c6c9cSStella Laurenzo   }
397436c6c9cSStella Laurenzo 
398436c6c9cSStella Laurenzo private:
399436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
400436c6c9cSStella Laurenzo   MlirBlock block;
401436c6c9cSStella Laurenzo };
402436c6c9cSStella Laurenzo 
403436c6c9cSStella Laurenzo } // namespace
404436c6c9cSStella Laurenzo 
405436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
406436c6c9cSStella Laurenzo // PyMlirContext
407436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
408436c6c9cSStella Laurenzo 
409436c6c9cSStella Laurenzo PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
410436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
411436c6c9cSStella Laurenzo   auto &liveContexts = getLiveContexts();
412436c6c9cSStella Laurenzo   liveContexts[context.ptr] = this;
413436c6c9cSStella Laurenzo }
414436c6c9cSStella Laurenzo 
415436c6c9cSStella Laurenzo PyMlirContext::~PyMlirContext() {
416436c6c9cSStella Laurenzo   // Note that the only public way to construct an instance is via the
417436c6c9cSStella Laurenzo   // forContext method, which always puts the associated handle into
418436c6c9cSStella Laurenzo   // liveContexts.
419436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
420436c6c9cSStella Laurenzo   getLiveContexts().erase(context.ptr);
421436c6c9cSStella Laurenzo   mlirContextDestroy(context);
422436c6c9cSStella Laurenzo }
423436c6c9cSStella Laurenzo 
424436c6c9cSStella Laurenzo py::object PyMlirContext::getCapsule() {
425436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
426436c6c9cSStella Laurenzo }
427436c6c9cSStella Laurenzo 
428436c6c9cSStella Laurenzo py::object PyMlirContext::createFromCapsule(py::object capsule) {
429436c6c9cSStella Laurenzo   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
430436c6c9cSStella Laurenzo   if (mlirContextIsNull(rawContext))
431436c6c9cSStella Laurenzo     throw py::error_already_set();
432436c6c9cSStella Laurenzo   return forContext(rawContext).releaseObject();
433436c6c9cSStella Laurenzo }
434436c6c9cSStella Laurenzo 
435436c6c9cSStella Laurenzo PyMlirContext *PyMlirContext::createNewContextForInit() {
436436c6c9cSStella Laurenzo   MlirContext context = mlirContextCreate();
437436c6c9cSStella Laurenzo   mlirRegisterAllDialects(context);
438436c6c9cSStella Laurenzo   return new PyMlirContext(context);
439436c6c9cSStella Laurenzo }
440436c6c9cSStella Laurenzo 
441436c6c9cSStella Laurenzo PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
442436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
443436c6c9cSStella Laurenzo   auto &liveContexts = getLiveContexts();
444436c6c9cSStella Laurenzo   auto it = liveContexts.find(context.ptr);
445436c6c9cSStella Laurenzo   if (it == liveContexts.end()) {
446436c6c9cSStella Laurenzo     // Create.
447436c6c9cSStella Laurenzo     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
448436c6c9cSStella Laurenzo     py::object pyRef = py::cast(unownedContextWrapper);
449436c6c9cSStella Laurenzo     assert(pyRef && "cast to py::object failed");
450436c6c9cSStella Laurenzo     liveContexts[context.ptr] = unownedContextWrapper;
451436c6c9cSStella Laurenzo     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
452436c6c9cSStella Laurenzo   }
453436c6c9cSStella Laurenzo   // Use existing.
454436c6c9cSStella Laurenzo   py::object pyRef = py::cast(it->second);
455436c6c9cSStella Laurenzo   return PyMlirContextRef(it->second, std::move(pyRef));
456436c6c9cSStella Laurenzo }
457436c6c9cSStella Laurenzo 
458436c6c9cSStella Laurenzo PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
459436c6c9cSStella Laurenzo   static LiveContextMap liveContexts;
460436c6c9cSStella Laurenzo   return liveContexts;
461436c6c9cSStella Laurenzo }
462436c6c9cSStella Laurenzo 
463436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
464436c6c9cSStella Laurenzo 
465436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
466436c6c9cSStella Laurenzo 
467436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
468436c6c9cSStella Laurenzo 
469436c6c9cSStella Laurenzo pybind11::object PyMlirContext::contextEnter() {
470436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushContext(*this);
471436c6c9cSStella Laurenzo }
472436c6c9cSStella Laurenzo 
473436c6c9cSStella Laurenzo void PyMlirContext::contextExit(pybind11::object excType,
474436c6c9cSStella Laurenzo                                 pybind11::object excVal,
475436c6c9cSStella Laurenzo                                 pybind11::object excTb) {
476436c6c9cSStella Laurenzo   PyThreadContextEntry::popContext(*this);
477436c6c9cSStella Laurenzo }
478436c6c9cSStella Laurenzo 
479436c6c9cSStella Laurenzo PyMlirContext &DefaultingPyMlirContext::resolve() {
480436c6c9cSStella Laurenzo   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
481436c6c9cSStella Laurenzo   if (!context) {
482436c6c9cSStella Laurenzo     throw SetPyError(
483436c6c9cSStella Laurenzo         PyExc_RuntimeError,
484436c6c9cSStella Laurenzo         "An MLIR function requires a Context but none was provided in the call "
485436c6c9cSStella Laurenzo         "or from the surrounding environment. Either pass to the function with "
486436c6c9cSStella Laurenzo         "a 'context=' argument or establish a default using 'with Context():'");
487436c6c9cSStella Laurenzo   }
488436c6c9cSStella Laurenzo   return *context;
489436c6c9cSStella Laurenzo }
490436c6c9cSStella Laurenzo 
491436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
492436c6c9cSStella Laurenzo // PyThreadContextEntry management
493436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
494436c6c9cSStella Laurenzo 
495436c6c9cSStella Laurenzo std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
496436c6c9cSStella Laurenzo   static thread_local std::vector<PyThreadContextEntry> stack;
497436c6c9cSStella Laurenzo   return stack;
498436c6c9cSStella Laurenzo }
499436c6c9cSStella Laurenzo 
500436c6c9cSStella Laurenzo PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
501436c6c9cSStella Laurenzo   auto &stack = getStack();
502436c6c9cSStella Laurenzo   if (stack.empty())
503436c6c9cSStella Laurenzo     return nullptr;
504436c6c9cSStella Laurenzo   return &stack.back();
505436c6c9cSStella Laurenzo }
506436c6c9cSStella Laurenzo 
507436c6c9cSStella Laurenzo void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
508436c6c9cSStella Laurenzo                                 py::object insertionPoint,
509436c6c9cSStella Laurenzo                                 py::object location) {
510436c6c9cSStella Laurenzo   auto &stack = getStack();
511436c6c9cSStella Laurenzo   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
512436c6c9cSStella Laurenzo                      std::move(location));
513436c6c9cSStella Laurenzo   // If the new stack has more than one entry and the context of the new top
514436c6c9cSStella Laurenzo   // entry matches the previous, copy the insertionPoint and location from the
515436c6c9cSStella Laurenzo   // previous entry if missing from the new top entry.
516436c6c9cSStella Laurenzo   if (stack.size() > 1) {
517436c6c9cSStella Laurenzo     auto &prev = *(stack.rbegin() + 1);
518436c6c9cSStella Laurenzo     auto &current = stack.back();
519436c6c9cSStella Laurenzo     if (current.context.is(prev.context)) {
520436c6c9cSStella Laurenzo       // Default non-context objects from the previous entry.
521436c6c9cSStella Laurenzo       if (!current.insertionPoint)
522436c6c9cSStella Laurenzo         current.insertionPoint = prev.insertionPoint;
523436c6c9cSStella Laurenzo       if (!current.location)
524436c6c9cSStella Laurenzo         current.location = prev.location;
525436c6c9cSStella Laurenzo     }
526436c6c9cSStella Laurenzo   }
527436c6c9cSStella Laurenzo }
528436c6c9cSStella Laurenzo 
529436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getContext() {
530436c6c9cSStella Laurenzo   if (!context)
531436c6c9cSStella Laurenzo     return nullptr;
532436c6c9cSStella Laurenzo   return py::cast<PyMlirContext *>(context);
533436c6c9cSStella Laurenzo }
534436c6c9cSStella Laurenzo 
535436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
536436c6c9cSStella Laurenzo   if (!insertionPoint)
537436c6c9cSStella Laurenzo     return nullptr;
538436c6c9cSStella Laurenzo   return py::cast<PyInsertionPoint *>(insertionPoint);
539436c6c9cSStella Laurenzo }
540436c6c9cSStella Laurenzo 
541436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getLocation() {
542436c6c9cSStella Laurenzo   if (!location)
543436c6c9cSStella Laurenzo     return nullptr;
544436c6c9cSStella Laurenzo   return py::cast<PyLocation *>(location);
545436c6c9cSStella Laurenzo }
546436c6c9cSStella Laurenzo 
547436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getDefaultContext() {
548436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
549436c6c9cSStella Laurenzo   return tos ? tos->getContext() : nullptr;
550436c6c9cSStella Laurenzo }
551436c6c9cSStella Laurenzo 
552436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
553436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
554436c6c9cSStella Laurenzo   return tos ? tos->getInsertionPoint() : nullptr;
555436c6c9cSStella Laurenzo }
556436c6c9cSStella Laurenzo 
557436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getDefaultLocation() {
558436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
559436c6c9cSStella Laurenzo   return tos ? tos->getLocation() : nullptr;
560436c6c9cSStella Laurenzo }
561436c6c9cSStella Laurenzo 
562436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
563436c6c9cSStella Laurenzo   py::object contextObj = py::cast(context);
564436c6c9cSStella Laurenzo   push(FrameKind::Context, /*context=*/contextObj,
565436c6c9cSStella Laurenzo        /*insertionPoint=*/py::object(),
566436c6c9cSStella Laurenzo        /*location=*/py::object());
567436c6c9cSStella Laurenzo   return contextObj;
568436c6c9cSStella Laurenzo }
569436c6c9cSStella Laurenzo 
570436c6c9cSStella Laurenzo void PyThreadContextEntry::popContext(PyMlirContext &context) {
571436c6c9cSStella Laurenzo   auto &stack = getStack();
572436c6c9cSStella Laurenzo   if (stack.empty())
573436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
574436c6c9cSStella Laurenzo   auto &tos = stack.back();
575436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
576436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
577436c6c9cSStella Laurenzo   stack.pop_back();
578436c6c9cSStella Laurenzo }
579436c6c9cSStella Laurenzo 
580436c6c9cSStella Laurenzo py::object
581436c6c9cSStella Laurenzo PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
582436c6c9cSStella Laurenzo   py::object contextObj =
583436c6c9cSStella Laurenzo       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
584436c6c9cSStella Laurenzo   py::object insertionPointObj = py::cast(insertionPoint);
585436c6c9cSStella Laurenzo   push(FrameKind::InsertionPoint,
586436c6c9cSStella Laurenzo        /*context=*/contextObj,
587436c6c9cSStella Laurenzo        /*insertionPoint=*/insertionPointObj,
588436c6c9cSStella Laurenzo        /*location=*/py::object());
589436c6c9cSStella Laurenzo   return insertionPointObj;
590436c6c9cSStella Laurenzo }
591436c6c9cSStella Laurenzo 
592436c6c9cSStella Laurenzo void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
593436c6c9cSStella Laurenzo   auto &stack = getStack();
594436c6c9cSStella Laurenzo   if (stack.empty())
595436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError,
596436c6c9cSStella Laurenzo                      "Unbalanced InsertionPoint enter/exit");
597436c6c9cSStella Laurenzo   auto &tos = stack.back();
598436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::InsertionPoint &&
599436c6c9cSStella Laurenzo       tos.getInsertionPoint() != &insertionPoint)
600436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError,
601436c6c9cSStella Laurenzo                      "Unbalanced InsertionPoint enter/exit");
602436c6c9cSStella Laurenzo   stack.pop_back();
603436c6c9cSStella Laurenzo }
604436c6c9cSStella Laurenzo 
605436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
606436c6c9cSStella Laurenzo   py::object contextObj = location.getContext().getObject();
607436c6c9cSStella Laurenzo   py::object locationObj = py::cast(location);
608436c6c9cSStella Laurenzo   push(FrameKind::Location, /*context=*/contextObj,
609436c6c9cSStella Laurenzo        /*insertionPoint=*/py::object(),
610436c6c9cSStella Laurenzo        /*location=*/locationObj);
611436c6c9cSStella Laurenzo   return locationObj;
612436c6c9cSStella Laurenzo }
613436c6c9cSStella Laurenzo 
614436c6c9cSStella Laurenzo void PyThreadContextEntry::popLocation(PyLocation &location) {
615436c6c9cSStella Laurenzo   auto &stack = getStack();
616436c6c9cSStella Laurenzo   if (stack.empty())
617436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
618436c6c9cSStella Laurenzo   auto &tos = stack.back();
619436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
620436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
621436c6c9cSStella Laurenzo   stack.pop_back();
622436c6c9cSStella Laurenzo }
623436c6c9cSStella Laurenzo 
624436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
625436c6c9cSStella Laurenzo // PyDialect, PyDialectDescriptor, PyDialects
626436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
627436c6c9cSStella Laurenzo 
628436c6c9cSStella Laurenzo MlirDialect PyDialects::getDialectForKey(const std::string &key,
629436c6c9cSStella Laurenzo                                          bool attrError) {
630436c6c9cSStella Laurenzo   // If the "std" dialect was asked for, substitute the empty namespace :(
631436c6c9cSStella Laurenzo   static const std::string emptyKey;
632436c6c9cSStella Laurenzo   const std::string *canonKey = key == "std" ? &emptyKey : &key;
633436c6c9cSStella Laurenzo   MlirDialect dialect = mlirContextGetOrLoadDialect(
634436c6c9cSStella Laurenzo       getContext()->get(), {canonKey->data(), canonKey->size()});
635436c6c9cSStella Laurenzo   if (mlirDialectIsNull(dialect)) {
636436c6c9cSStella Laurenzo     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
637436c6c9cSStella Laurenzo                      Twine("Dialect '") + key + "' not found");
638436c6c9cSStella Laurenzo   }
639436c6c9cSStella Laurenzo   return dialect;
640436c6c9cSStella Laurenzo }
641436c6c9cSStella Laurenzo 
642436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
643436c6c9cSStella Laurenzo // PyLocation
644436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
645436c6c9cSStella Laurenzo 
646436c6c9cSStella Laurenzo py::object PyLocation::getCapsule() {
647436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
648436c6c9cSStella Laurenzo }
649436c6c9cSStella Laurenzo 
650436c6c9cSStella Laurenzo PyLocation PyLocation::createFromCapsule(py::object capsule) {
651436c6c9cSStella Laurenzo   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
652436c6c9cSStella Laurenzo   if (mlirLocationIsNull(rawLoc))
653436c6c9cSStella Laurenzo     throw py::error_already_set();
654436c6c9cSStella Laurenzo   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
655436c6c9cSStella Laurenzo                     rawLoc);
656436c6c9cSStella Laurenzo }
657436c6c9cSStella Laurenzo 
658436c6c9cSStella Laurenzo py::object PyLocation::contextEnter() {
659436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushLocation(*this);
660436c6c9cSStella Laurenzo }
661436c6c9cSStella Laurenzo 
662436c6c9cSStella Laurenzo void PyLocation::contextExit(py::object excType, py::object excVal,
663436c6c9cSStella Laurenzo                              py::object excTb) {
664436c6c9cSStella Laurenzo   PyThreadContextEntry::popLocation(*this);
665436c6c9cSStella Laurenzo }
666436c6c9cSStella Laurenzo 
667436c6c9cSStella Laurenzo PyLocation &DefaultingPyLocation::resolve() {
668436c6c9cSStella Laurenzo   auto *location = PyThreadContextEntry::getDefaultLocation();
669436c6c9cSStella Laurenzo   if (!location) {
670436c6c9cSStella Laurenzo     throw SetPyError(
671436c6c9cSStella Laurenzo         PyExc_RuntimeError,
672436c6c9cSStella Laurenzo         "An MLIR function requires a Location but none was provided in the "
673436c6c9cSStella Laurenzo         "call or from the surrounding environment. Either pass to the function "
674436c6c9cSStella Laurenzo         "with a 'loc=' argument or establish a default using 'with loc:'");
675436c6c9cSStella Laurenzo   }
676436c6c9cSStella Laurenzo   return *location;
677436c6c9cSStella Laurenzo }
678436c6c9cSStella Laurenzo 
679436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
680436c6c9cSStella Laurenzo // PyModule
681436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
682436c6c9cSStella Laurenzo 
683436c6c9cSStella Laurenzo PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
684436c6c9cSStella Laurenzo     : BaseContextObject(std::move(contextRef)), module(module) {}
685436c6c9cSStella Laurenzo 
686436c6c9cSStella Laurenzo PyModule::~PyModule() {
687436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
688436c6c9cSStella Laurenzo   auto &liveModules = getContext()->liveModules;
689436c6c9cSStella Laurenzo   assert(liveModules.count(module.ptr) == 1 &&
690436c6c9cSStella Laurenzo          "destroying module not in live map");
691436c6c9cSStella Laurenzo   liveModules.erase(module.ptr);
692436c6c9cSStella Laurenzo   mlirModuleDestroy(module);
693436c6c9cSStella Laurenzo }
694436c6c9cSStella Laurenzo 
695436c6c9cSStella Laurenzo PyModuleRef PyModule::forModule(MlirModule module) {
696436c6c9cSStella Laurenzo   MlirContext context = mlirModuleGetContext(module);
697436c6c9cSStella Laurenzo   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
698436c6c9cSStella Laurenzo 
699436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
700436c6c9cSStella Laurenzo   auto &liveModules = contextRef->liveModules;
701436c6c9cSStella Laurenzo   auto it = liveModules.find(module.ptr);
702436c6c9cSStella Laurenzo   if (it == liveModules.end()) {
703436c6c9cSStella Laurenzo     // Create.
704436c6c9cSStella Laurenzo     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
705436c6c9cSStella Laurenzo     // Note that the default return value policy on cast is automatic_reference,
706436c6c9cSStella Laurenzo     // which does not take ownership (delete will not be called).
707436c6c9cSStella Laurenzo     // Just be explicit.
708436c6c9cSStella Laurenzo     py::object pyRef =
709436c6c9cSStella Laurenzo         py::cast(unownedModule, py::return_value_policy::take_ownership);
710436c6c9cSStella Laurenzo     unownedModule->handle = pyRef;
711436c6c9cSStella Laurenzo     liveModules[module.ptr] =
712436c6c9cSStella Laurenzo         std::make_pair(unownedModule->handle, unownedModule);
713436c6c9cSStella Laurenzo     return PyModuleRef(unownedModule, std::move(pyRef));
714436c6c9cSStella Laurenzo   }
715436c6c9cSStella Laurenzo   // Use existing.
716436c6c9cSStella Laurenzo   PyModule *existing = it->second.second;
717436c6c9cSStella Laurenzo   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
718436c6c9cSStella Laurenzo   return PyModuleRef(existing, std::move(pyRef));
719436c6c9cSStella Laurenzo }
720436c6c9cSStella Laurenzo 
721436c6c9cSStella Laurenzo py::object PyModule::createFromCapsule(py::object capsule) {
722436c6c9cSStella Laurenzo   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
723436c6c9cSStella Laurenzo   if (mlirModuleIsNull(rawModule))
724436c6c9cSStella Laurenzo     throw py::error_already_set();
725436c6c9cSStella Laurenzo   return forModule(rawModule).releaseObject();
726436c6c9cSStella Laurenzo }
727436c6c9cSStella Laurenzo 
728436c6c9cSStella Laurenzo py::object PyModule::getCapsule() {
729436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
730436c6c9cSStella Laurenzo }
731436c6c9cSStella Laurenzo 
732436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
733436c6c9cSStella Laurenzo // PyOperation
734436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
735436c6c9cSStella Laurenzo 
736436c6c9cSStella Laurenzo PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
737436c6c9cSStella Laurenzo     : BaseContextObject(std::move(contextRef)), operation(operation) {}
738436c6c9cSStella Laurenzo 
739436c6c9cSStella Laurenzo PyOperation::~PyOperation() {
740436c6c9cSStella Laurenzo   auto &liveOperations = getContext()->liveOperations;
741436c6c9cSStella Laurenzo   assert(liveOperations.count(operation.ptr) == 1 &&
742436c6c9cSStella Laurenzo          "destroying operation not in live map");
743436c6c9cSStella Laurenzo   liveOperations.erase(operation.ptr);
744436c6c9cSStella Laurenzo   if (!isAttached()) {
745436c6c9cSStella Laurenzo     mlirOperationDestroy(operation);
746436c6c9cSStella Laurenzo   }
747436c6c9cSStella Laurenzo }
748436c6c9cSStella Laurenzo 
749436c6c9cSStella Laurenzo PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
750436c6c9cSStella Laurenzo                                            MlirOperation operation,
751436c6c9cSStella Laurenzo                                            py::object parentKeepAlive) {
752436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
753436c6c9cSStella Laurenzo   // Create.
754436c6c9cSStella Laurenzo   PyOperation *unownedOperation =
755436c6c9cSStella Laurenzo       new PyOperation(std::move(contextRef), operation);
756436c6c9cSStella Laurenzo   // Note that the default return value policy on cast is automatic_reference,
757436c6c9cSStella Laurenzo   // which does not take ownership (delete will not be called).
758436c6c9cSStella Laurenzo   // Just be explicit.
759436c6c9cSStella Laurenzo   py::object pyRef =
760436c6c9cSStella Laurenzo       py::cast(unownedOperation, py::return_value_policy::take_ownership);
761436c6c9cSStella Laurenzo   unownedOperation->handle = pyRef;
762436c6c9cSStella Laurenzo   if (parentKeepAlive) {
763436c6c9cSStella Laurenzo     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
764436c6c9cSStella Laurenzo   }
765436c6c9cSStella Laurenzo   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
766436c6c9cSStella Laurenzo   return PyOperationRef(unownedOperation, std::move(pyRef));
767436c6c9cSStella Laurenzo }
768436c6c9cSStella Laurenzo 
769436c6c9cSStella Laurenzo PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
770436c6c9cSStella Laurenzo                                          MlirOperation operation,
771436c6c9cSStella Laurenzo                                          py::object parentKeepAlive) {
772436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
773436c6c9cSStella Laurenzo   auto it = liveOperations.find(operation.ptr);
774436c6c9cSStella Laurenzo   if (it == liveOperations.end()) {
775436c6c9cSStella Laurenzo     // Create.
776436c6c9cSStella Laurenzo     return createInstance(std::move(contextRef), operation,
777436c6c9cSStella Laurenzo                           std::move(parentKeepAlive));
778436c6c9cSStella Laurenzo   }
779436c6c9cSStella Laurenzo   // Use existing.
780436c6c9cSStella Laurenzo   PyOperation *existing = it->second.second;
781436c6c9cSStella Laurenzo   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
782436c6c9cSStella Laurenzo   return PyOperationRef(existing, std::move(pyRef));
783436c6c9cSStella Laurenzo }
784436c6c9cSStella Laurenzo 
785436c6c9cSStella Laurenzo PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
786436c6c9cSStella Laurenzo                                            MlirOperation operation,
787436c6c9cSStella Laurenzo                                            py::object parentKeepAlive) {
788436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
789436c6c9cSStella Laurenzo   assert(liveOperations.count(operation.ptr) == 0 &&
790436c6c9cSStella Laurenzo          "cannot create detached operation that already exists");
791436c6c9cSStella Laurenzo   (void)liveOperations;
792436c6c9cSStella Laurenzo 
793436c6c9cSStella Laurenzo   PyOperationRef created = createInstance(std::move(contextRef), operation,
794436c6c9cSStella Laurenzo                                           std::move(parentKeepAlive));
795436c6c9cSStella Laurenzo   created->attached = false;
796436c6c9cSStella Laurenzo   return created;
797436c6c9cSStella Laurenzo }
798436c6c9cSStella Laurenzo 
799436c6c9cSStella Laurenzo void PyOperation::checkValid() const {
800436c6c9cSStella Laurenzo   if (!valid) {
801436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
802436c6c9cSStella Laurenzo   }
803436c6c9cSStella Laurenzo }
804436c6c9cSStella Laurenzo 
805436c6c9cSStella Laurenzo void PyOperationBase::print(py::object fileObject, bool binary,
806436c6c9cSStella Laurenzo                             llvm::Optional<int64_t> largeElementsLimit,
807436c6c9cSStella Laurenzo                             bool enableDebugInfo, bool prettyDebugInfo,
808436c6c9cSStella Laurenzo                             bool printGenericOpForm, bool useLocalScope) {
809436c6c9cSStella Laurenzo   PyOperation &operation = getOperation();
810436c6c9cSStella Laurenzo   operation.checkValid();
811436c6c9cSStella Laurenzo   if (fileObject.is_none())
812436c6c9cSStella Laurenzo     fileObject = py::module::import("sys").attr("stdout");
813436c6c9cSStella Laurenzo 
814436c6c9cSStella Laurenzo   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
815436c6c9cSStella Laurenzo     fileObject.attr("write")("// Verification failed, printing generic form\n");
816436c6c9cSStella Laurenzo     printGenericOpForm = true;
817436c6c9cSStella Laurenzo   }
818436c6c9cSStella Laurenzo 
819436c6c9cSStella Laurenzo   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
820436c6c9cSStella Laurenzo   if (largeElementsLimit)
821436c6c9cSStella Laurenzo     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
822436c6c9cSStella Laurenzo   if (enableDebugInfo)
823436c6c9cSStella Laurenzo     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
824436c6c9cSStella Laurenzo   if (printGenericOpForm)
825436c6c9cSStella Laurenzo     mlirOpPrintingFlagsPrintGenericOpForm(flags);
826436c6c9cSStella Laurenzo 
827436c6c9cSStella Laurenzo   PyFileAccumulator accum(fileObject, binary);
828436c6c9cSStella Laurenzo   py::gil_scoped_release();
829436c6c9cSStella Laurenzo   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
830436c6c9cSStella Laurenzo                               accum.getUserData());
831436c6c9cSStella Laurenzo   mlirOpPrintingFlagsDestroy(flags);
832436c6c9cSStella Laurenzo }
833436c6c9cSStella Laurenzo 
834436c6c9cSStella Laurenzo py::object PyOperationBase::getAsm(bool binary,
835436c6c9cSStella Laurenzo                                    llvm::Optional<int64_t> largeElementsLimit,
836436c6c9cSStella Laurenzo                                    bool enableDebugInfo, bool prettyDebugInfo,
837436c6c9cSStella Laurenzo                                    bool printGenericOpForm,
838436c6c9cSStella Laurenzo                                    bool useLocalScope) {
839436c6c9cSStella Laurenzo   py::object fileObject;
840436c6c9cSStella Laurenzo   if (binary) {
841436c6c9cSStella Laurenzo     fileObject = py::module::import("io").attr("BytesIO")();
842436c6c9cSStella Laurenzo   } else {
843436c6c9cSStella Laurenzo     fileObject = py::module::import("io").attr("StringIO")();
844436c6c9cSStella Laurenzo   }
845436c6c9cSStella Laurenzo   print(fileObject, /*binary=*/binary,
846436c6c9cSStella Laurenzo         /*largeElementsLimit=*/largeElementsLimit,
847436c6c9cSStella Laurenzo         /*enableDebugInfo=*/enableDebugInfo,
848436c6c9cSStella Laurenzo         /*prettyDebugInfo=*/prettyDebugInfo,
849436c6c9cSStella Laurenzo         /*printGenericOpForm=*/printGenericOpForm,
850436c6c9cSStella Laurenzo         /*useLocalScope=*/useLocalScope);
851436c6c9cSStella Laurenzo 
852436c6c9cSStella Laurenzo   return fileObject.attr("getvalue")();
853436c6c9cSStella Laurenzo }
854436c6c9cSStella Laurenzo 
855436c6c9cSStella Laurenzo PyOperationRef PyOperation::getParentOperation() {
856436c6c9cSStella Laurenzo   if (!isAttached())
857436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
858436c6c9cSStella Laurenzo   MlirOperation operation = mlirOperationGetParentOperation(get());
859436c6c9cSStella Laurenzo   if (mlirOperationIsNull(operation))
860436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "Operation has no parent.");
861436c6c9cSStella Laurenzo   return PyOperation::forOperation(getContext(), operation);
862436c6c9cSStella Laurenzo }
863436c6c9cSStella Laurenzo 
864436c6c9cSStella Laurenzo PyBlock PyOperation::getBlock() {
865436c6c9cSStella Laurenzo   PyOperationRef parentOperation = getParentOperation();
866436c6c9cSStella Laurenzo   MlirBlock block = mlirOperationGetBlock(get());
867436c6c9cSStella Laurenzo   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
868436c6c9cSStella Laurenzo   return PyBlock{std::move(parentOperation), block};
869436c6c9cSStella Laurenzo }
870436c6c9cSStella Laurenzo 
8710126e906SJohn Demme py::object PyOperation::getCapsule() {
8720126e906SJohn Demme   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
8730126e906SJohn Demme }
8740126e906SJohn Demme 
8750126e906SJohn Demme py::object PyOperation::createFromCapsule(py::object capsule) {
8760126e906SJohn Demme   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
8770126e906SJohn Demme   if (mlirOperationIsNull(rawOperation))
8780126e906SJohn Demme     throw py::error_already_set();
8790126e906SJohn Demme   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
8800126e906SJohn Demme   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
8810126e906SJohn Demme       .releaseObject();
8820126e906SJohn Demme }
8830126e906SJohn Demme 
884436c6c9cSStella Laurenzo py::object PyOperation::create(
885436c6c9cSStella Laurenzo     std::string name, llvm::Optional<std::vector<PyType *>> results,
886436c6c9cSStella Laurenzo     llvm::Optional<std::vector<PyValue *>> operands,
887436c6c9cSStella Laurenzo     llvm::Optional<py::dict> attributes,
888436c6c9cSStella Laurenzo     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
889436c6c9cSStella Laurenzo     DefaultingPyLocation location, py::object maybeIp) {
890436c6c9cSStella Laurenzo   llvm::SmallVector<MlirValue, 4> mlirOperands;
891436c6c9cSStella Laurenzo   llvm::SmallVector<MlirType, 4> mlirResults;
892436c6c9cSStella Laurenzo   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
893436c6c9cSStella Laurenzo   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
894436c6c9cSStella Laurenzo 
895436c6c9cSStella Laurenzo   // General parameter validation.
896436c6c9cSStella Laurenzo   if (regions < 0)
897436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
898436c6c9cSStella Laurenzo 
899436c6c9cSStella Laurenzo   // Unpack/validate operands.
900436c6c9cSStella Laurenzo   if (operands) {
901436c6c9cSStella Laurenzo     mlirOperands.reserve(operands->size());
902436c6c9cSStella Laurenzo     for (PyValue *operand : *operands) {
903436c6c9cSStella Laurenzo       if (!operand)
904436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
905436c6c9cSStella Laurenzo       mlirOperands.push_back(operand->get());
906436c6c9cSStella Laurenzo     }
907436c6c9cSStella Laurenzo   }
908436c6c9cSStella Laurenzo 
909436c6c9cSStella Laurenzo   // Unpack/validate results.
910436c6c9cSStella Laurenzo   if (results) {
911436c6c9cSStella Laurenzo     mlirResults.reserve(results->size());
912436c6c9cSStella Laurenzo     for (PyType *result : *results) {
913436c6c9cSStella Laurenzo       // TODO: Verify result type originate from the same context.
914436c6c9cSStella Laurenzo       if (!result)
915436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "result type cannot be None");
916436c6c9cSStella Laurenzo       mlirResults.push_back(*result);
917436c6c9cSStella Laurenzo     }
918436c6c9cSStella Laurenzo   }
919436c6c9cSStella Laurenzo   // Unpack/validate attributes.
920436c6c9cSStella Laurenzo   if (attributes) {
921436c6c9cSStella Laurenzo     mlirAttributes.reserve(attributes->size());
922436c6c9cSStella Laurenzo     for (auto &it : *attributes) {
923436c6c9cSStella Laurenzo       std::string key;
924436c6c9cSStella Laurenzo       try {
925436c6c9cSStella Laurenzo         key = it.first.cast<std::string>();
926436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
927436c6c9cSStella Laurenzo         std::string msg = "Invalid attribute key (not a string) when "
928436c6c9cSStella Laurenzo                           "attempting to create the operation \"" +
929436c6c9cSStella Laurenzo                           name + "\" (" + err.what() + ")";
930436c6c9cSStella Laurenzo         throw py::cast_error(msg);
931436c6c9cSStella Laurenzo       }
932436c6c9cSStella Laurenzo       try {
933436c6c9cSStella Laurenzo         auto &attribute = it.second.cast<PyAttribute &>();
934436c6c9cSStella Laurenzo         // TODO: Verify attribute originates from the same context.
935436c6c9cSStella Laurenzo         mlirAttributes.emplace_back(std::move(key), attribute);
936436c6c9cSStella Laurenzo       } catch (py::reference_cast_error &) {
937436c6c9cSStella Laurenzo         // This exception seems thrown when the value is "None".
938436c6c9cSStella Laurenzo         std::string msg =
939436c6c9cSStella Laurenzo             "Found an invalid (`None`?) attribute value for the key \"" + key +
940436c6c9cSStella Laurenzo             "\" when attempting to create the operation \"" + name + "\"";
941436c6c9cSStella Laurenzo         throw py::cast_error(msg);
942436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
943436c6c9cSStella Laurenzo         std::string msg = "Invalid attribute value for the key \"" + key +
944436c6c9cSStella Laurenzo                           "\" when attempting to create the operation \"" +
945436c6c9cSStella Laurenzo                           name + "\" (" + err.what() + ")";
946436c6c9cSStella Laurenzo         throw py::cast_error(msg);
947436c6c9cSStella Laurenzo       }
948436c6c9cSStella Laurenzo     }
949436c6c9cSStella Laurenzo   }
950436c6c9cSStella Laurenzo   // Unpack/validate successors.
951436c6c9cSStella Laurenzo   if (successors) {
952436c6c9cSStella Laurenzo     llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
953436c6c9cSStella Laurenzo     mlirSuccessors.reserve(successors->size());
954436c6c9cSStella Laurenzo     for (auto *successor : *successors) {
955436c6c9cSStella Laurenzo       // TODO: Verify successor originate from the same context.
956436c6c9cSStella Laurenzo       if (!successor)
957436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
958436c6c9cSStella Laurenzo       mlirSuccessors.push_back(successor->get());
959436c6c9cSStella Laurenzo     }
960436c6c9cSStella Laurenzo   }
961436c6c9cSStella Laurenzo 
962436c6c9cSStella Laurenzo   // Apply unpacked/validated to the operation state. Beyond this
963436c6c9cSStella Laurenzo   // point, exceptions cannot be thrown or else the state will leak.
964436c6c9cSStella Laurenzo   MlirOperationState state =
965436c6c9cSStella Laurenzo       mlirOperationStateGet(toMlirStringRef(name), location);
966436c6c9cSStella Laurenzo   if (!mlirOperands.empty())
967436c6c9cSStella Laurenzo     mlirOperationStateAddOperands(&state, mlirOperands.size(),
968436c6c9cSStella Laurenzo                                   mlirOperands.data());
969436c6c9cSStella Laurenzo   if (!mlirResults.empty())
970436c6c9cSStella Laurenzo     mlirOperationStateAddResults(&state, mlirResults.size(),
971436c6c9cSStella Laurenzo                                  mlirResults.data());
972436c6c9cSStella Laurenzo   if (!mlirAttributes.empty()) {
973436c6c9cSStella Laurenzo     // Note that the attribute names directly reference bytes in
974436c6c9cSStella Laurenzo     // mlirAttributes, so that vector must not be changed from here
975436c6c9cSStella Laurenzo     // on.
976436c6c9cSStella Laurenzo     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
977436c6c9cSStella Laurenzo     mlirNamedAttributes.reserve(mlirAttributes.size());
978436c6c9cSStella Laurenzo     for (auto &it : mlirAttributes)
979436c6c9cSStella Laurenzo       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
980436c6c9cSStella Laurenzo           mlirIdentifierGet(mlirAttributeGetContext(it.second),
981436c6c9cSStella Laurenzo                             toMlirStringRef(it.first)),
982436c6c9cSStella Laurenzo           it.second));
983436c6c9cSStella Laurenzo     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
984436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
985436c6c9cSStella Laurenzo   }
986436c6c9cSStella Laurenzo   if (!mlirSuccessors.empty())
987436c6c9cSStella Laurenzo     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
988436c6c9cSStella Laurenzo                                     mlirSuccessors.data());
989436c6c9cSStella Laurenzo   if (regions) {
990436c6c9cSStella Laurenzo     llvm::SmallVector<MlirRegion, 4> mlirRegions;
991436c6c9cSStella Laurenzo     mlirRegions.resize(regions);
992436c6c9cSStella Laurenzo     for (int i = 0; i < regions; ++i)
993436c6c9cSStella Laurenzo       mlirRegions[i] = mlirRegionCreate();
994436c6c9cSStella Laurenzo     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
995436c6c9cSStella Laurenzo                                       mlirRegions.data());
996436c6c9cSStella Laurenzo   }
997436c6c9cSStella Laurenzo 
998436c6c9cSStella Laurenzo   // Construct the operation.
999436c6c9cSStella Laurenzo   MlirOperation operation = mlirOperationCreate(&state);
1000436c6c9cSStella Laurenzo   PyOperationRef created =
1001436c6c9cSStella Laurenzo       PyOperation::createDetached(location->getContext(), operation);
1002436c6c9cSStella Laurenzo 
1003436c6c9cSStella Laurenzo   // InsertPoint active?
1004436c6c9cSStella Laurenzo   if (!maybeIp.is(py::cast(false))) {
1005436c6c9cSStella Laurenzo     PyInsertionPoint *ip;
1006436c6c9cSStella Laurenzo     if (maybeIp.is_none()) {
1007436c6c9cSStella Laurenzo       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1008436c6c9cSStella Laurenzo     } else {
1009436c6c9cSStella Laurenzo       ip = py::cast<PyInsertionPoint *>(maybeIp);
1010436c6c9cSStella Laurenzo     }
1011436c6c9cSStella Laurenzo     if (ip)
1012436c6c9cSStella Laurenzo       ip->insert(*created.get());
1013436c6c9cSStella Laurenzo   }
1014436c6c9cSStella Laurenzo 
1015436c6c9cSStella Laurenzo   return created->createOpView();
1016436c6c9cSStella Laurenzo }
1017436c6c9cSStella Laurenzo 
1018436c6c9cSStella Laurenzo py::object PyOperation::createOpView() {
1019436c6c9cSStella Laurenzo   MlirIdentifier ident = mlirOperationGetName(get());
1020436c6c9cSStella Laurenzo   MlirStringRef identStr = mlirIdentifierStr(ident);
1021436c6c9cSStella Laurenzo   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1022436c6c9cSStella Laurenzo       StringRef(identStr.data, identStr.length));
1023436c6c9cSStella Laurenzo   if (opViewClass)
1024436c6c9cSStella Laurenzo     return (*opViewClass)(getRef().getObject());
1025436c6c9cSStella Laurenzo   return py::cast(PyOpView(getRef().getObject()));
1026436c6c9cSStella Laurenzo }
1027436c6c9cSStella Laurenzo 
1028436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1029436c6c9cSStella Laurenzo // PyOpView
1030436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1031436c6c9cSStella Laurenzo 
1032436c6c9cSStella Laurenzo py::object
1033436c6c9cSStella Laurenzo PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1034436c6c9cSStella Laurenzo                        py::list operandList,
1035436c6c9cSStella Laurenzo                        llvm::Optional<py::dict> attributes,
1036436c6c9cSStella Laurenzo                        llvm::Optional<std::vector<PyBlock *>> successors,
1037436c6c9cSStella Laurenzo                        llvm::Optional<int> regions,
1038436c6c9cSStella Laurenzo                        DefaultingPyLocation location, py::object maybeIp) {
1039436c6c9cSStella Laurenzo   PyMlirContextRef context = location->getContext();
1040436c6c9cSStella Laurenzo   // Class level operation construction metadata.
1041436c6c9cSStella Laurenzo   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1042436c6c9cSStella Laurenzo   // Operand and result segment specs are either none, which does no
1043436c6c9cSStella Laurenzo   // variadic unpacking, or a list of ints with segment sizes, where each
1044436c6c9cSStella Laurenzo   // element is either a positive number (typically 1 for a scalar) or -1 to
1045436c6c9cSStella Laurenzo   // indicate that it is derived from the length of the same-indexed operand
1046436c6c9cSStella Laurenzo   // or result (implying that it is a list at that position).
1047436c6c9cSStella Laurenzo   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1048436c6c9cSStella Laurenzo   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1049436c6c9cSStella Laurenzo 
10508d05a288SStella Laurenzo   std::vector<uint32_t> operandSegmentLengths;
10518d05a288SStella Laurenzo   std::vector<uint32_t> resultSegmentLengths;
1052436c6c9cSStella Laurenzo 
1053436c6c9cSStella Laurenzo   // Validate/determine region count.
1054436c6c9cSStella Laurenzo   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1055436c6c9cSStella Laurenzo   int opMinRegionCount = std::get<0>(opRegionSpec);
1056436c6c9cSStella Laurenzo   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1057436c6c9cSStella Laurenzo   if (!regions) {
1058436c6c9cSStella Laurenzo     regions = opMinRegionCount;
1059436c6c9cSStella Laurenzo   }
1060436c6c9cSStella Laurenzo   if (*regions < opMinRegionCount) {
1061436c6c9cSStella Laurenzo     throw py::value_error(
1062436c6c9cSStella Laurenzo         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1063436c6c9cSStella Laurenzo          llvm::Twine(opMinRegionCount) +
1064436c6c9cSStella Laurenzo          " regions but was built with regions=" + llvm::Twine(*regions))
1065436c6c9cSStella Laurenzo             .str());
1066436c6c9cSStella Laurenzo   }
1067436c6c9cSStella Laurenzo   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1068436c6c9cSStella Laurenzo     throw py::value_error(
1069436c6c9cSStella Laurenzo         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1070436c6c9cSStella Laurenzo          llvm::Twine(opMinRegionCount) +
1071436c6c9cSStella Laurenzo          " regions but was built with regions=" + llvm::Twine(*regions))
1072436c6c9cSStella Laurenzo             .str());
1073436c6c9cSStella Laurenzo   }
1074436c6c9cSStella Laurenzo 
1075436c6c9cSStella Laurenzo   // Unpack results.
1076436c6c9cSStella Laurenzo   std::vector<PyType *> resultTypes;
1077436c6c9cSStella Laurenzo   resultTypes.reserve(resultTypeList.size());
1078436c6c9cSStella Laurenzo   if (resultSegmentSpecObj.is_none()) {
1079436c6c9cSStella Laurenzo     // Non-variadic result unpacking.
1080436c6c9cSStella Laurenzo     for (auto it : llvm::enumerate(resultTypeList)) {
1081436c6c9cSStella Laurenzo       try {
1082436c6c9cSStella Laurenzo         resultTypes.push_back(py::cast<PyType *>(it.value()));
1083436c6c9cSStella Laurenzo         if (!resultTypes.back())
1084436c6c9cSStella Laurenzo           throw py::cast_error();
1085436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
1086436c6c9cSStella Laurenzo         throw py::value_error((llvm::Twine("Result ") +
1087436c6c9cSStella Laurenzo                                llvm::Twine(it.index()) + " of operation \"" +
1088436c6c9cSStella Laurenzo                                name + "\" must be a Type (" + err.what() + ")")
1089436c6c9cSStella Laurenzo                                   .str());
1090436c6c9cSStella Laurenzo       }
1091436c6c9cSStella Laurenzo     }
1092436c6c9cSStella Laurenzo   } else {
1093436c6c9cSStella Laurenzo     // Sized result unpacking.
1094436c6c9cSStella Laurenzo     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1095436c6c9cSStella Laurenzo     if (resultSegmentSpec.size() != resultTypeList.size()) {
1096436c6c9cSStella Laurenzo       throw py::value_error((llvm::Twine("Operation \"") + name +
1097436c6c9cSStella Laurenzo                              "\" requires " +
1098436c6c9cSStella Laurenzo                              llvm::Twine(resultSegmentSpec.size()) +
1099436c6c9cSStella Laurenzo                              "result segments but was provided " +
1100436c6c9cSStella Laurenzo                              llvm::Twine(resultTypeList.size()))
1101436c6c9cSStella Laurenzo                                 .str());
1102436c6c9cSStella Laurenzo     }
1103436c6c9cSStella Laurenzo     resultSegmentLengths.reserve(resultTypeList.size());
1104436c6c9cSStella Laurenzo     for (auto it :
1105436c6c9cSStella Laurenzo          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1106436c6c9cSStella Laurenzo       int segmentSpec = std::get<1>(it.value());
1107436c6c9cSStella Laurenzo       if (segmentSpec == 1 || segmentSpec == 0) {
1108436c6c9cSStella Laurenzo         // Unpack unary element.
1109436c6c9cSStella Laurenzo         try {
1110436c6c9cSStella Laurenzo           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1111436c6c9cSStella Laurenzo           if (resultType) {
1112436c6c9cSStella Laurenzo             resultTypes.push_back(resultType);
1113436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(1);
1114436c6c9cSStella Laurenzo           } else if (segmentSpec == 0) {
1115436c6c9cSStella Laurenzo             // Allowed to be optional.
1116436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(0);
1117436c6c9cSStella Laurenzo           } else {
1118436c6c9cSStella Laurenzo             throw py::cast_error("was None and result is not optional");
1119436c6c9cSStella Laurenzo           }
1120436c6c9cSStella Laurenzo         } catch (py::cast_error &err) {
1121436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Result ") +
1122436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1123436c6c9cSStella Laurenzo                                  name + "\" must be a Type (" + err.what() +
1124436c6c9cSStella Laurenzo                                  ")")
1125436c6c9cSStella Laurenzo                                     .str());
1126436c6c9cSStella Laurenzo         }
1127436c6c9cSStella Laurenzo       } else if (segmentSpec == -1) {
1128436c6c9cSStella Laurenzo         // Unpack sequence by appending.
1129436c6c9cSStella Laurenzo         try {
1130436c6c9cSStella Laurenzo           if (std::get<0>(it.value()).is_none()) {
1131436c6c9cSStella Laurenzo             // Treat it as an empty list.
1132436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(0);
1133436c6c9cSStella Laurenzo           } else {
1134436c6c9cSStella Laurenzo             // Unpack the list.
1135436c6c9cSStella Laurenzo             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1136436c6c9cSStella Laurenzo             for (py::object segmentItem : segment) {
1137436c6c9cSStella Laurenzo               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1138436c6c9cSStella Laurenzo               if (!resultTypes.back()) {
1139436c6c9cSStella Laurenzo                 throw py::cast_error("contained a None item");
1140436c6c9cSStella Laurenzo               }
1141436c6c9cSStella Laurenzo             }
1142436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(segment.size());
1143436c6c9cSStella Laurenzo           }
1144436c6c9cSStella Laurenzo         } catch (std::exception &err) {
1145436c6c9cSStella Laurenzo           // NOTE: Sloppy to be using a catch-all here, but there are at least
1146436c6c9cSStella Laurenzo           // three different unrelated exceptions that can be thrown in the
1147436c6c9cSStella Laurenzo           // above "casts". Just keep the scope above small and catch them all.
1148436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Result ") +
1149436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1150436c6c9cSStella Laurenzo                                  name + "\" must be a Sequence of Types (" +
1151436c6c9cSStella Laurenzo                                  err.what() + ")")
1152436c6c9cSStella Laurenzo                                     .str());
1153436c6c9cSStella Laurenzo         }
1154436c6c9cSStella Laurenzo       } else {
1155436c6c9cSStella Laurenzo         throw py::value_error("Unexpected segment spec");
1156436c6c9cSStella Laurenzo       }
1157436c6c9cSStella Laurenzo     }
1158436c6c9cSStella Laurenzo   }
1159436c6c9cSStella Laurenzo 
1160436c6c9cSStella Laurenzo   // Unpack operands.
1161436c6c9cSStella Laurenzo   std::vector<PyValue *> operands;
1162436c6c9cSStella Laurenzo   operands.reserve(operands.size());
1163436c6c9cSStella Laurenzo   if (operandSegmentSpecObj.is_none()) {
1164436c6c9cSStella Laurenzo     // Non-sized operand unpacking.
1165436c6c9cSStella Laurenzo     for (auto it : llvm::enumerate(operandList)) {
1166436c6c9cSStella Laurenzo       try {
1167436c6c9cSStella Laurenzo         operands.push_back(py::cast<PyValue *>(it.value()));
1168436c6c9cSStella Laurenzo         if (!operands.back())
1169436c6c9cSStella Laurenzo           throw py::cast_error();
1170436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
1171436c6c9cSStella Laurenzo         throw py::value_error((llvm::Twine("Operand ") +
1172436c6c9cSStella Laurenzo                                llvm::Twine(it.index()) + " of operation \"" +
1173436c6c9cSStella Laurenzo                                name + "\" must be a Value (" + err.what() + ")")
1174436c6c9cSStella Laurenzo                                   .str());
1175436c6c9cSStella Laurenzo       }
1176436c6c9cSStella Laurenzo     }
1177436c6c9cSStella Laurenzo   } else {
1178436c6c9cSStella Laurenzo     // Sized operand unpacking.
1179436c6c9cSStella Laurenzo     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1180436c6c9cSStella Laurenzo     if (operandSegmentSpec.size() != operandList.size()) {
1181436c6c9cSStella Laurenzo       throw py::value_error((llvm::Twine("Operation \"") + name +
1182436c6c9cSStella Laurenzo                              "\" requires " +
1183436c6c9cSStella Laurenzo                              llvm::Twine(operandSegmentSpec.size()) +
1184436c6c9cSStella Laurenzo                              "operand segments but was provided " +
1185436c6c9cSStella Laurenzo                              llvm::Twine(operandList.size()))
1186436c6c9cSStella Laurenzo                                 .str());
1187436c6c9cSStella Laurenzo     }
1188436c6c9cSStella Laurenzo     operandSegmentLengths.reserve(operandList.size());
1189436c6c9cSStella Laurenzo     for (auto it :
1190436c6c9cSStella Laurenzo          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1191436c6c9cSStella Laurenzo       int segmentSpec = std::get<1>(it.value());
1192436c6c9cSStella Laurenzo       if (segmentSpec == 1 || segmentSpec == 0) {
1193436c6c9cSStella Laurenzo         // Unpack unary element.
1194436c6c9cSStella Laurenzo         try {
1195436c6c9cSStella Laurenzo           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1196436c6c9cSStella Laurenzo           if (operandValue) {
1197436c6c9cSStella Laurenzo             operands.push_back(operandValue);
1198436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(1);
1199436c6c9cSStella Laurenzo           } else if (segmentSpec == 0) {
1200436c6c9cSStella Laurenzo             // Allowed to be optional.
1201436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(0);
1202436c6c9cSStella Laurenzo           } else {
1203436c6c9cSStella Laurenzo             throw py::cast_error("was None and operand is not optional");
1204436c6c9cSStella Laurenzo           }
1205436c6c9cSStella Laurenzo         } catch (py::cast_error &err) {
1206436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Operand ") +
1207436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1208436c6c9cSStella Laurenzo                                  name + "\" must be a Value (" + err.what() +
1209436c6c9cSStella Laurenzo                                  ")")
1210436c6c9cSStella Laurenzo                                     .str());
1211436c6c9cSStella Laurenzo         }
1212436c6c9cSStella Laurenzo       } else if (segmentSpec == -1) {
1213436c6c9cSStella Laurenzo         // Unpack sequence by appending.
1214436c6c9cSStella Laurenzo         try {
1215436c6c9cSStella Laurenzo           if (std::get<0>(it.value()).is_none()) {
1216436c6c9cSStella Laurenzo             // Treat it as an empty list.
1217436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(0);
1218436c6c9cSStella Laurenzo           } else {
1219436c6c9cSStella Laurenzo             // Unpack the list.
1220436c6c9cSStella Laurenzo             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1221436c6c9cSStella Laurenzo             for (py::object segmentItem : segment) {
1222436c6c9cSStella Laurenzo               operands.push_back(py::cast<PyValue *>(segmentItem));
1223436c6c9cSStella Laurenzo               if (!operands.back()) {
1224436c6c9cSStella Laurenzo                 throw py::cast_error("contained a None item");
1225436c6c9cSStella Laurenzo               }
1226436c6c9cSStella Laurenzo             }
1227436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(segment.size());
1228436c6c9cSStella Laurenzo           }
1229436c6c9cSStella Laurenzo         } catch (std::exception &err) {
1230436c6c9cSStella Laurenzo           // NOTE: Sloppy to be using a catch-all here, but there are at least
1231436c6c9cSStella Laurenzo           // three different unrelated exceptions that can be thrown in the
1232436c6c9cSStella Laurenzo           // above "casts". Just keep the scope above small and catch them all.
1233436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Operand ") +
1234436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1235436c6c9cSStella Laurenzo                                  name + "\" must be a Sequence of Values (" +
1236436c6c9cSStella Laurenzo                                  err.what() + ")")
1237436c6c9cSStella Laurenzo                                     .str());
1238436c6c9cSStella Laurenzo         }
1239436c6c9cSStella Laurenzo       } else {
1240436c6c9cSStella Laurenzo         throw py::value_error("Unexpected segment spec");
1241436c6c9cSStella Laurenzo       }
1242436c6c9cSStella Laurenzo     }
1243436c6c9cSStella Laurenzo   }
1244436c6c9cSStella Laurenzo 
1245436c6c9cSStella Laurenzo   // Merge operand/result segment lengths into attributes if needed.
1246436c6c9cSStella Laurenzo   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1247436c6c9cSStella Laurenzo     // Dup.
1248436c6c9cSStella Laurenzo     if (attributes) {
1249436c6c9cSStella Laurenzo       attributes = py::dict(*attributes);
1250436c6c9cSStella Laurenzo     } else {
1251436c6c9cSStella Laurenzo       attributes = py::dict();
1252436c6c9cSStella Laurenzo     }
1253436c6c9cSStella Laurenzo     if (attributes->contains("result_segment_sizes") ||
1254436c6c9cSStella Laurenzo         attributes->contains("operand_segment_sizes")) {
1255436c6c9cSStella Laurenzo       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1256436c6c9cSStella Laurenzo                             "'operand_segment_sizes' attribute is unsupported. "
1257436c6c9cSStella Laurenzo                             "Use Operation.create for such low-level access.");
1258436c6c9cSStella Laurenzo     }
1259436c6c9cSStella Laurenzo 
1260436c6c9cSStella Laurenzo     // Add result_segment_sizes attribute.
1261436c6c9cSStella Laurenzo     if (!resultSegmentLengths.empty()) {
1262436c6c9cSStella Laurenzo       int64_t size = resultSegmentLengths.size();
12638d05a288SStella Laurenzo       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
12648d05a288SStella Laurenzo           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1265436c6c9cSStella Laurenzo           resultSegmentLengths.size(), resultSegmentLengths.data());
1266436c6c9cSStella Laurenzo       (*attributes)["result_segment_sizes"] =
1267436c6c9cSStella Laurenzo           PyAttribute(context, segmentLengthAttr);
1268436c6c9cSStella Laurenzo     }
1269436c6c9cSStella Laurenzo 
1270436c6c9cSStella Laurenzo     // Add operand_segment_sizes attribute.
1271436c6c9cSStella Laurenzo     if (!operandSegmentLengths.empty()) {
1272436c6c9cSStella Laurenzo       int64_t size = operandSegmentLengths.size();
12738d05a288SStella Laurenzo       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
12748d05a288SStella Laurenzo           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1275436c6c9cSStella Laurenzo           operandSegmentLengths.size(), operandSegmentLengths.data());
1276436c6c9cSStella Laurenzo       (*attributes)["operand_segment_sizes"] =
1277436c6c9cSStella Laurenzo           PyAttribute(context, segmentLengthAttr);
1278436c6c9cSStella Laurenzo     }
1279436c6c9cSStella Laurenzo   }
1280436c6c9cSStella Laurenzo 
1281436c6c9cSStella Laurenzo   // Delegate to create.
1282436c6c9cSStella Laurenzo   return PyOperation::create(std::move(name),
1283436c6c9cSStella Laurenzo                              /*results=*/std::move(resultTypes),
1284436c6c9cSStella Laurenzo                              /*operands=*/std::move(operands),
1285436c6c9cSStella Laurenzo                              /*attributes=*/std::move(attributes),
1286436c6c9cSStella Laurenzo                              /*successors=*/std::move(successors),
1287436c6c9cSStella Laurenzo                              /*regions=*/*regions, location, maybeIp);
1288436c6c9cSStella Laurenzo }
1289436c6c9cSStella Laurenzo 
/// Wraps an existing operation as a generic OpView.
///
/// @param operationObject Any Python object castable to `PyOperationBase &`.
///        The stored `operationObject` member is the Python object held by
///        the underlying operation's ref. That may differ from the argument,
///        e.g. when an OpView rather than an Operation is passed.
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      // NOTE: reads the `operation` member, so this relies on `operation`
      // being declared (and therefore initialized) before `operationObject`.
      operationObject(operation.getRef().getObject()) {}
1295436c6c9cSStella Laurenzo 
1296436c6c9cSStella Laurenzo py::object PyOpView::createRawSubclass(py::object userClass) {
1297436c6c9cSStella Laurenzo   // This is... a little gross. The typical pattern is to have a pure python
1298436c6c9cSStella Laurenzo   // class that extends OpView like:
1299436c6c9cSStella Laurenzo   //   class AddFOp(_cext.ir.OpView):
1300436c6c9cSStella Laurenzo   //     def __init__(self, loc, lhs, rhs):
1301436c6c9cSStella Laurenzo   //       operation = loc.context.create_operation(
1302436c6c9cSStella Laurenzo   //           "addf", lhs, rhs, results=[lhs.type])
1303436c6c9cSStella Laurenzo   //       super().__init__(operation)
1304436c6c9cSStella Laurenzo   //
1305436c6c9cSStella Laurenzo   // I.e. The goal of the user facing type is to provide a nice constructor
1306436c6c9cSStella Laurenzo   // that has complete freedom for the op under construction. This is at odds
1307436c6c9cSStella Laurenzo   // with our other desire to sometimes create this object by just passing an
1308436c6c9cSStella Laurenzo   // operation (to initialize the base class). We could do *arg and **kwargs
1309436c6c9cSStella Laurenzo   // munging to try to make it work, but instead, we synthesize a new class
1310436c6c9cSStella Laurenzo   // on the fly which extends this user class (AddFOp in this example) and
1311436c6c9cSStella Laurenzo   // *give it* the base class's __init__ method, thus bypassing the
1312436c6c9cSStella Laurenzo   // intermediate subclass's __init__ method entirely. While slightly,
1313436c6c9cSStella Laurenzo   // underhanded, this is safe/legal because the type hierarchy has not changed
1314436c6c9cSStella Laurenzo   // (we just added a new leaf) and we aren't mucking around with __new__.
1315436c6c9cSStella Laurenzo   // Typically, this new class will be stored on the original as "_Raw" and will
1316436c6c9cSStella Laurenzo   // be used for casts and other things that need a variant of the class that
1317436c6c9cSStella Laurenzo   // is initialized purely from an operation.
1318436c6c9cSStella Laurenzo   py::object parentMetaclass =
1319436c6c9cSStella Laurenzo       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1320436c6c9cSStella Laurenzo   py::dict attributes;
1321436c6c9cSStella Laurenzo   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1322436c6c9cSStella Laurenzo   // now.
1323436c6c9cSStella Laurenzo   //   auto opViewType = py::type::of<PyOpView>();
1324436c6c9cSStella Laurenzo   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1325436c6c9cSStella Laurenzo   attributes["__init__"] = opViewType.attr("__init__");
1326436c6c9cSStella Laurenzo   py::str origName = userClass.attr("__name__");
1327436c6c9cSStella Laurenzo   py::str newName = py::str("_") + origName;
1328436c6c9cSStella Laurenzo   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1329436c6c9cSStella Laurenzo }
1330436c6c9cSStella Laurenzo 
1331436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1332436c6c9cSStella Laurenzo // PyInsertionPoint.
1333436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1334436c6c9cSStella Laurenzo 
/// Creates an insertion point with no reference operation. `insert` then
/// places operations at the end of `block`.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1336436c6c9cSStella Laurenzo 
/// Creates an insertion point that inserts immediately before the given
/// operation, inside that operation's block.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      // NOTE: `block` is derived from `refOperation`, so this relies on
      // `refOperation` being declared before `block` in the class.
      block((*refOperation)->getBlock()) {}
1340436c6c9cSStella Laurenzo 
1341436c6c9cSStella Laurenzo void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1342436c6c9cSStella Laurenzo   PyOperation &operation = operationBase.getOperation();
1343436c6c9cSStella Laurenzo   if (operation.isAttached())
1344436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError,
1345436c6c9cSStella Laurenzo                      "Attempt to insert operation that is already attached");
1346436c6c9cSStella Laurenzo   block.getParentOperation()->checkValid();
1347436c6c9cSStella Laurenzo   MlirOperation beforeOp = {nullptr};
1348436c6c9cSStella Laurenzo   if (refOperation) {
1349436c6c9cSStella Laurenzo     // Insert before operation.
1350436c6c9cSStella Laurenzo     (*refOperation)->checkValid();
1351436c6c9cSStella Laurenzo     beforeOp = (*refOperation)->get();
1352436c6c9cSStella Laurenzo   } else {
1353436c6c9cSStella Laurenzo     // Insert at end (before null) is only valid if the block does not
1354436c6c9cSStella Laurenzo     // already end in a known terminator (violating this will cause assertion
1355436c6c9cSStella Laurenzo     // failures later).
1356436c6c9cSStella Laurenzo     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1357436c6c9cSStella Laurenzo       throw py::index_error("Cannot insert operation at the end of a block "
1358436c6c9cSStella Laurenzo                             "that already has a terminator. Did you mean to "
1359436c6c9cSStella Laurenzo                             "use 'InsertionPoint.at_block_terminator(block)' "
1360436c6c9cSStella Laurenzo                             "versus 'InsertionPoint(block)'?");
1361436c6c9cSStella Laurenzo     }
1362436c6c9cSStella Laurenzo   }
1363436c6c9cSStella Laurenzo   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1364436c6c9cSStella Laurenzo   operation.setAttached();
1365436c6c9cSStella Laurenzo }
1366436c6c9cSStella Laurenzo 
1367436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1368436c6c9cSStella Laurenzo   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1369436c6c9cSStella Laurenzo   if (mlirOperationIsNull(firstOp)) {
1370436c6c9cSStella Laurenzo     // Just insert at end.
1371436c6c9cSStella Laurenzo     return PyInsertionPoint(block);
1372436c6c9cSStella Laurenzo   }
1373436c6c9cSStella Laurenzo 
1374436c6c9cSStella Laurenzo   // Insert before first op.
1375436c6c9cSStella Laurenzo   PyOperationRef firstOpRef = PyOperation::forOperation(
1376436c6c9cSStella Laurenzo       block.getParentOperation()->getContext(), firstOp);
1377436c6c9cSStella Laurenzo   return PyInsertionPoint{block, std::move(firstOpRef)};
1378436c6c9cSStella Laurenzo }
1379436c6c9cSStella Laurenzo 
1380436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1381436c6c9cSStella Laurenzo   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1382436c6c9cSStella Laurenzo   if (mlirOperationIsNull(terminator))
1383436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1384436c6c9cSStella Laurenzo   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1385436c6c9cSStella Laurenzo       block.getParentOperation()->getContext(), terminator);
1386436c6c9cSStella Laurenzo   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1387436c6c9cSStella Laurenzo }
1388436c6c9cSStella Laurenzo 
1389436c6c9cSStella Laurenzo py::object PyInsertionPoint::contextEnter() {
1390436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushInsertionPoint(*this);
1391436c6c9cSStella Laurenzo }
1392436c6c9cSStella Laurenzo 
1393436c6c9cSStella Laurenzo void PyInsertionPoint::contextExit(pybind11::object excType,
1394436c6c9cSStella Laurenzo                                    pybind11::object excVal,
1395436c6c9cSStella Laurenzo                                    pybind11::object excTb) {
1396436c6c9cSStella Laurenzo   PyThreadContextEntry::popInsertionPoint(*this);
1397436c6c9cSStella Laurenzo }
1398436c6c9cSStella Laurenzo 
1399436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1400436c6c9cSStella Laurenzo // PyAttribute.
1401436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1402436c6c9cSStella Laurenzo 
1403436c6c9cSStella Laurenzo bool PyAttribute::operator==(const PyAttribute &other) {
1404436c6c9cSStella Laurenzo   return mlirAttributeEqual(attr, other.attr);
1405436c6c9cSStella Laurenzo }
1406436c6c9cSStella Laurenzo 
1407436c6c9cSStella Laurenzo py::object PyAttribute::getCapsule() {
1408436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1409436c6c9cSStella Laurenzo }
1410436c6c9cSStella Laurenzo 
1411436c6c9cSStella Laurenzo PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1412436c6c9cSStella Laurenzo   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1413436c6c9cSStella Laurenzo   if (mlirAttributeIsNull(rawAttr))
1414436c6c9cSStella Laurenzo     throw py::error_already_set();
1415436c6c9cSStella Laurenzo   return PyAttribute(
1416436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1417436c6c9cSStella Laurenzo }
1418436c6c9cSStella Laurenzo 
1419436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1420436c6c9cSStella Laurenzo // PyNamedAttribute.
1421436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1422436c6c9cSStella Laurenzo 
1423436c6c9cSStella Laurenzo PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1424436c6c9cSStella Laurenzo     : ownedName(new std::string(std::move(ownedName))) {
1425436c6c9cSStella Laurenzo   namedAttr = mlirNamedAttributeGet(
1426436c6c9cSStella Laurenzo       mlirIdentifierGet(mlirAttributeGetContext(attr),
1427436c6c9cSStella Laurenzo                         toMlirStringRef(*this->ownedName)),
1428436c6c9cSStella Laurenzo       attr);
1429436c6c9cSStella Laurenzo }
1430436c6c9cSStella Laurenzo 
1431436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1432436c6c9cSStella Laurenzo // PyType.
1433436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1434436c6c9cSStella Laurenzo 
1435436c6c9cSStella Laurenzo bool PyType::operator==(const PyType &other) {
1436436c6c9cSStella Laurenzo   return mlirTypeEqual(type, other.type);
1437436c6c9cSStella Laurenzo }
1438436c6c9cSStella Laurenzo 
1439436c6c9cSStella Laurenzo py::object PyType::getCapsule() {
1440436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1441436c6c9cSStella Laurenzo }
1442436c6c9cSStella Laurenzo 
1443436c6c9cSStella Laurenzo PyType PyType::createFromCapsule(py::object capsule) {
1444436c6c9cSStella Laurenzo   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1445436c6c9cSStella Laurenzo   if (mlirTypeIsNull(rawType))
1446436c6c9cSStella Laurenzo     throw py::error_already_set();
1447436c6c9cSStella Laurenzo   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1448436c6c9cSStella Laurenzo                 rawType);
1449436c6c9cSStella Laurenzo }
1450436c6c9cSStella Laurenzo 
1451436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1452436c6c9cSStella Laurenzo // PyValue and subclases.
1453436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1454436c6c9cSStella Laurenzo 
1455436c6c9cSStella Laurenzo namespace {
1456436c6c9cSStella Laurenzo /// CRTP base class for Python MLIR values that subclass Value and should be
1457436c6c9cSStella Laurenzo /// castable from it. The value hierarchy is one level deep and is not supposed
1458436c6c9cSStella Laurenzo /// to accommodate other levels unless core MLIR changes.
1459436c6c9cSStella Laurenzo template <typename DerivedTy>
1460436c6c9cSStella Laurenzo class PyConcreteValue : public PyValue {
1461436c6c9cSStella Laurenzo public:
1462436c6c9cSStella Laurenzo   // Derived classes must define statics for:
1463436c6c9cSStella Laurenzo   //   IsAFunctionTy isaFunction
1464436c6c9cSStella Laurenzo   //   const char *pyClassName
1465436c6c9cSStella Laurenzo   // and redefine bindDerived.
1466436c6c9cSStella Laurenzo   using ClassTy = py::class_<DerivedTy, PyValue>;
1467436c6c9cSStella Laurenzo   using IsAFunctionTy = bool (*)(MlirValue);
1468436c6c9cSStella Laurenzo 
1469436c6c9cSStella Laurenzo   PyConcreteValue() = default;
1470436c6c9cSStella Laurenzo   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1471436c6c9cSStella Laurenzo       : PyValue(operationRef, value) {}
1472436c6c9cSStella Laurenzo   PyConcreteValue(PyValue &orig)
1473436c6c9cSStella Laurenzo       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1474436c6c9cSStella Laurenzo 
1475436c6c9cSStella Laurenzo   /// Attempts to cast the original value to the derived type and throws on
1476436c6c9cSStella Laurenzo   /// type mismatches.
1477436c6c9cSStella Laurenzo   static MlirValue castFrom(PyValue &orig) {
1478436c6c9cSStella Laurenzo     if (!DerivedTy::isaFunction(orig.get())) {
1479436c6c9cSStella Laurenzo       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1480436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1481436c6c9cSStella Laurenzo                                              DerivedTy::pyClassName +
1482436c6c9cSStella Laurenzo                                              " (from " + origRepr + ")");
1483436c6c9cSStella Laurenzo     }
1484436c6c9cSStella Laurenzo     return orig.get();
1485436c6c9cSStella Laurenzo   }
1486436c6c9cSStella Laurenzo 
1487436c6c9cSStella Laurenzo   /// Binds the Python module objects to functions of this class.
1488436c6c9cSStella Laurenzo   static void bind(py::module &m) {
1489436c6c9cSStella Laurenzo     auto cls = ClassTy(m, DerivedTy::pyClassName);
1490436c6c9cSStella Laurenzo     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1491436c6c9cSStella Laurenzo     DerivedTy::bindDerived(cls);
1492436c6c9cSStella Laurenzo   }
1493436c6c9cSStella Laurenzo 
1494436c6c9cSStella Laurenzo   /// Implemented by derived classes to add methods to the Python subclass.
1495436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &m) {}
1496436c6c9cSStella Laurenzo };
1497436c6c9cSStella Laurenzo 
1498436c6c9cSStella Laurenzo /// Python wrapper for MlirBlockArgument.
1499436c6c9cSStella Laurenzo class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1500436c6c9cSStella Laurenzo public:
1501436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1502436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BlockArgument";
1503436c6c9cSStella Laurenzo   using PyConcreteValue::PyConcreteValue;
1504436c6c9cSStella Laurenzo 
1505436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1506436c6c9cSStella Laurenzo     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1507436c6c9cSStella Laurenzo       return PyBlock(self.getParentOperation(),
1508436c6c9cSStella Laurenzo                      mlirBlockArgumentGetOwner(self.get()));
1509436c6c9cSStella Laurenzo     });
1510436c6c9cSStella Laurenzo     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1511436c6c9cSStella Laurenzo       return mlirBlockArgumentGetArgNumber(self.get());
1512436c6c9cSStella Laurenzo     });
1513436c6c9cSStella Laurenzo     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1514436c6c9cSStella Laurenzo       return mlirBlockArgumentSetType(self.get(), type);
1515436c6c9cSStella Laurenzo     });
1516436c6c9cSStella Laurenzo   }
1517436c6c9cSStella Laurenzo };
1518436c6c9cSStella Laurenzo 
1519436c6c9cSStella Laurenzo /// Python wrapper for MlirOpResult.
1520436c6c9cSStella Laurenzo class PyOpResult : public PyConcreteValue<PyOpResult> {
1521436c6c9cSStella Laurenzo public:
1522436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1523436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpResult";
1524436c6c9cSStella Laurenzo   using PyConcreteValue::PyConcreteValue;
1525436c6c9cSStella Laurenzo 
1526436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1527436c6c9cSStella Laurenzo     c.def_property_readonly("owner", [](PyOpResult &self) {
1528436c6c9cSStella Laurenzo       assert(
1529436c6c9cSStella Laurenzo           mlirOperationEqual(self.getParentOperation()->get(),
1530436c6c9cSStella Laurenzo                              mlirOpResultGetOwner(self.get())) &&
1531436c6c9cSStella Laurenzo           "expected the owner of the value in Python to match that in the IR");
1532436c6c9cSStella Laurenzo       return self.getParentOperation();
1533436c6c9cSStella Laurenzo     });
1534436c6c9cSStella Laurenzo     c.def_property_readonly("result_number", [](PyOpResult &self) {
1535436c6c9cSStella Laurenzo       return mlirOpResultGetResultNumber(self.get());
1536436c6c9cSStella Laurenzo     });
1537436c6c9cSStella Laurenzo   }
1538436c6c9cSStella Laurenzo };
1539436c6c9cSStella Laurenzo 
1540436c6c9cSStella Laurenzo /// A list of block arguments. Internally, these are stored as consecutive
1541436c6c9cSStella Laurenzo /// elements, random access is cheap. The argument list is associated with the
1542436c6c9cSStella Laurenzo /// operation that contains the block (detached blocks are not allowed in
1543436c6c9cSStella Laurenzo /// Python bindings) and extends its lifetime.
1544436c6c9cSStella Laurenzo class PyBlockArgumentList {
1545436c6c9cSStella Laurenzo public:
1546436c6c9cSStella Laurenzo   PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
1547436c6c9cSStella Laurenzo       : operation(std::move(operation)), block(block) {}
1548436c6c9cSStella Laurenzo 
1549436c6c9cSStella Laurenzo   /// Returns the length of the block argument list.
1550436c6c9cSStella Laurenzo   intptr_t dunderLen() {
1551436c6c9cSStella Laurenzo     operation->checkValid();
1552436c6c9cSStella Laurenzo     return mlirBlockGetNumArguments(block);
1553436c6c9cSStella Laurenzo   }
1554436c6c9cSStella Laurenzo 
1555436c6c9cSStella Laurenzo   /// Returns `index`-th element of the block argument list.
1556436c6c9cSStella Laurenzo   PyBlockArgument dunderGetItem(intptr_t index) {
1557436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
1558436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
1559436c6c9cSStella Laurenzo                        "attempt to access out of bounds region");
1560436c6c9cSStella Laurenzo     }
1561436c6c9cSStella Laurenzo     PyValue value(operation, mlirBlockGetArgument(block, index));
1562436c6c9cSStella Laurenzo     return PyBlockArgument(value);
1563436c6c9cSStella Laurenzo   }
1564436c6c9cSStella Laurenzo 
1565436c6c9cSStella Laurenzo   /// Defines a Python class in the bindings.
1566436c6c9cSStella Laurenzo   static void bind(py::module &m) {
1567436c6c9cSStella Laurenzo     py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
1568436c6c9cSStella Laurenzo         .def("__len__", &PyBlockArgumentList::dunderLen)
1569436c6c9cSStella Laurenzo         .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
1570436c6c9cSStella Laurenzo   }
1571436c6c9cSStella Laurenzo 
1572436c6c9cSStella Laurenzo private:
1573436c6c9cSStella Laurenzo   PyOperationRef operation;
1574436c6c9cSStella Laurenzo   MlirBlock block;
1575436c6c9cSStella Laurenzo };
1576436c6c9cSStella Laurenzo 
1577436c6c9cSStella Laurenzo /// A list of operation operands. Internally, these are stored as consecutive
1578436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the
1579436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this
1580436c6c9cSStella Laurenzo /// operation.
1581436c6c9cSStella Laurenzo class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1582436c6c9cSStella Laurenzo public:
1583436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpOperandList";
1584436c6c9cSStella Laurenzo 
1585436c6c9cSStella Laurenzo   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1586436c6c9cSStella Laurenzo                   intptr_t length = -1, intptr_t step = 1)
1587436c6c9cSStella Laurenzo       : Sliceable(startIndex,
1588436c6c9cSStella Laurenzo                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1589436c6c9cSStella Laurenzo                                : length,
1590436c6c9cSStella Laurenzo                   step),
1591436c6c9cSStella Laurenzo         operation(operation) {}
1592436c6c9cSStella Laurenzo 
1593436c6c9cSStella Laurenzo   intptr_t getNumElements() {
1594436c6c9cSStella Laurenzo     operation->checkValid();
1595436c6c9cSStella Laurenzo     return mlirOperationGetNumOperands(operation->get());
1596436c6c9cSStella Laurenzo   }
1597436c6c9cSStella Laurenzo 
1598436c6c9cSStella Laurenzo   PyValue getElement(intptr_t pos) {
1599436c6c9cSStella Laurenzo     return PyValue(operation, mlirOperationGetOperand(operation->get(), pos));
1600436c6c9cSStella Laurenzo   }
1601436c6c9cSStella Laurenzo 
1602436c6c9cSStella Laurenzo   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1603436c6c9cSStella Laurenzo     return PyOpOperandList(operation, startIndex, length, step);
1604436c6c9cSStella Laurenzo   }
1605436c6c9cSStella Laurenzo 
1606436c6c9cSStella Laurenzo private:
1607436c6c9cSStella Laurenzo   PyOperationRef operation;
1608436c6c9cSStella Laurenzo };
1609436c6c9cSStella Laurenzo 
1610436c6c9cSStella Laurenzo /// A list of operation results. Internally, these are stored as consecutive
1611436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the
1612436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this
1613436c6c9cSStella Laurenzo /// operation.
1614436c6c9cSStella Laurenzo class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1615436c6c9cSStella Laurenzo public:
1616436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpResultList";
1617436c6c9cSStella Laurenzo 
1618436c6c9cSStella Laurenzo   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1619436c6c9cSStella Laurenzo                  intptr_t length = -1, intptr_t step = 1)
1620436c6c9cSStella Laurenzo       : Sliceable(startIndex,
1621436c6c9cSStella Laurenzo                   length == -1 ? mlirOperationGetNumResults(operation->get())
1622436c6c9cSStella Laurenzo                                : length,
1623436c6c9cSStella Laurenzo                   step),
1624436c6c9cSStella Laurenzo         operation(operation) {}
1625436c6c9cSStella Laurenzo 
1626436c6c9cSStella Laurenzo   intptr_t getNumElements() {
1627436c6c9cSStella Laurenzo     operation->checkValid();
1628436c6c9cSStella Laurenzo     return mlirOperationGetNumResults(operation->get());
1629436c6c9cSStella Laurenzo   }
1630436c6c9cSStella Laurenzo 
1631436c6c9cSStella Laurenzo   PyOpResult getElement(intptr_t index) {
1632436c6c9cSStella Laurenzo     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1633436c6c9cSStella Laurenzo     return PyOpResult(value);
1634436c6c9cSStella Laurenzo   }
1635436c6c9cSStella Laurenzo 
1636436c6c9cSStella Laurenzo   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1637436c6c9cSStella Laurenzo     return PyOpResultList(operation, startIndex, length, step);
1638436c6c9cSStella Laurenzo   }
1639436c6c9cSStella Laurenzo 
1640436c6c9cSStella Laurenzo private:
1641436c6c9cSStella Laurenzo   PyOperationRef operation;
1642436c6c9cSStella Laurenzo };
1643436c6c9cSStella Laurenzo 
1644436c6c9cSStella Laurenzo /// A list of operation attributes. Can be indexed by name, producing
1645436c6c9cSStella Laurenzo /// attributes, or by index, producing named attributes.
1646436c6c9cSStella Laurenzo class PyOpAttributeMap {
1647436c6c9cSStella Laurenzo public:
1648436c6c9cSStella Laurenzo   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1649436c6c9cSStella Laurenzo 
1650436c6c9cSStella Laurenzo   PyAttribute dunderGetItemNamed(const std::string &name) {
1651436c6c9cSStella Laurenzo     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1652436c6c9cSStella Laurenzo                                                          toMlirStringRef(name));
1653436c6c9cSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
1654436c6c9cSStella Laurenzo       throw SetPyError(PyExc_KeyError,
1655436c6c9cSStella Laurenzo                        "attempt to access a non-existent attribute");
1656436c6c9cSStella Laurenzo     }
1657436c6c9cSStella Laurenzo     return PyAttribute(operation->getContext(), attr);
1658436c6c9cSStella Laurenzo   }
1659436c6c9cSStella Laurenzo 
1660436c6c9cSStella Laurenzo   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1661436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
1662436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
1663436c6c9cSStella Laurenzo                        "attempt to access out of bounds attribute");
1664436c6c9cSStella Laurenzo     }
1665436c6c9cSStella Laurenzo     MlirNamedAttribute namedAttr =
1666436c6c9cSStella Laurenzo         mlirOperationGetAttribute(operation->get(), index);
1667436c6c9cSStella Laurenzo     return PyNamedAttribute(
1668436c6c9cSStella Laurenzo         namedAttr.attribute,
1669436c6c9cSStella Laurenzo         std::string(mlirIdentifierStr(namedAttr.name).data));
1670436c6c9cSStella Laurenzo   }
1671436c6c9cSStella Laurenzo 
1672436c6c9cSStella Laurenzo   void dunderSetItem(const std::string &name, PyAttribute attr) {
1673436c6c9cSStella Laurenzo     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1674436c6c9cSStella Laurenzo                                     attr);
1675436c6c9cSStella Laurenzo   }
1676436c6c9cSStella Laurenzo 
1677436c6c9cSStella Laurenzo   void dunderDelItem(const std::string &name) {
1678436c6c9cSStella Laurenzo     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1679436c6c9cSStella Laurenzo                                                      toMlirStringRef(name));
1680436c6c9cSStella Laurenzo     if (!removed)
1681436c6c9cSStella Laurenzo       throw SetPyError(PyExc_KeyError,
1682436c6c9cSStella Laurenzo                        "attempt to delete a non-existent attribute");
1683436c6c9cSStella Laurenzo   }
1684436c6c9cSStella Laurenzo 
1685436c6c9cSStella Laurenzo   intptr_t dunderLen() {
1686436c6c9cSStella Laurenzo     return mlirOperationGetNumAttributes(operation->get());
1687436c6c9cSStella Laurenzo   }
1688436c6c9cSStella Laurenzo 
1689436c6c9cSStella Laurenzo   bool dunderContains(const std::string &name) {
1690436c6c9cSStella Laurenzo     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1691436c6c9cSStella Laurenzo         operation->get(), toMlirStringRef(name)));
1692436c6c9cSStella Laurenzo   }
1693436c6c9cSStella Laurenzo 
1694436c6c9cSStella Laurenzo   static void bind(py::module &m) {
1695436c6c9cSStella Laurenzo     py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
1696436c6c9cSStella Laurenzo         .def("__contains__", &PyOpAttributeMap::dunderContains)
1697436c6c9cSStella Laurenzo         .def("__len__", &PyOpAttributeMap::dunderLen)
1698436c6c9cSStella Laurenzo         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1699436c6c9cSStella Laurenzo         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1700436c6c9cSStella Laurenzo         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1701436c6c9cSStella Laurenzo         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1702436c6c9cSStella Laurenzo   }
1703436c6c9cSStella Laurenzo 
1704436c6c9cSStella Laurenzo private:
1705436c6c9cSStella Laurenzo   PyOperationRef operation;
1706436c6c9cSStella Laurenzo };
1707436c6c9cSStella Laurenzo 
1708436c6c9cSStella Laurenzo } // end namespace
1709436c6c9cSStella Laurenzo 
1710436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1711436c6c9cSStella Laurenzo // Populates the core exports of the 'ir' submodule.
1712436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1713436c6c9cSStella Laurenzo 
1714436c6c9cSStella Laurenzo void mlir::python::populateIRCore(py::module &m) {
1715436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1716*caa159f0SNicolas Vasilache   // Mapping of Global functions
1717*caa159f0SNicolas Vasilache   //----------------------------------------------------------------------------
1718*caa159f0SNicolas Vasilache   m.def("_enable_debug", [](bool enable) { mlirEnableGlobalDebug(enable); });
1719*caa159f0SNicolas Vasilache 
1720*caa159f0SNicolas Vasilache   //----------------------------------------------------------------------------
1721436c6c9cSStella Laurenzo   // Mapping of MlirContext
1722436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1723436c6c9cSStella Laurenzo   py::class_<PyMlirContext>(m, "Context")
1724436c6c9cSStella Laurenzo       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1725436c6c9cSStella Laurenzo       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1726436c6c9cSStella Laurenzo       .def("_get_context_again",
1727436c6c9cSStella Laurenzo            [](PyMlirContext &self) {
1728436c6c9cSStella Laurenzo              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1729436c6c9cSStella Laurenzo              return ref.releaseObject();
1730436c6c9cSStella Laurenzo            })
1731436c6c9cSStella Laurenzo       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1732436c6c9cSStella Laurenzo       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1733436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1734436c6c9cSStella Laurenzo                              &PyMlirContext::getCapsule)
1735436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1736436c6c9cSStella Laurenzo       .def("__enter__", &PyMlirContext::contextEnter)
1737436c6c9cSStella Laurenzo       .def("__exit__", &PyMlirContext::contextExit)
1738436c6c9cSStella Laurenzo       .def_property_readonly_static(
1739436c6c9cSStella Laurenzo           "current",
1740436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
1741436c6c9cSStella Laurenzo             auto *context = PyThreadContextEntry::getDefaultContext();
1742436c6c9cSStella Laurenzo             if (!context)
1743436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current Context");
1744436c6c9cSStella Laurenzo             return context;
1745436c6c9cSStella Laurenzo           },
1746436c6c9cSStella Laurenzo           "Gets the Context bound to the current thread or raises ValueError")
1747436c6c9cSStella Laurenzo       .def_property_readonly(
1748436c6c9cSStella Laurenzo           "dialects",
1749436c6c9cSStella Laurenzo           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1750436c6c9cSStella Laurenzo           "Gets a container for accessing dialects by name")
1751436c6c9cSStella Laurenzo       .def_property_readonly(
1752436c6c9cSStella Laurenzo           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1753436c6c9cSStella Laurenzo           "Alias for 'dialect'")
1754436c6c9cSStella Laurenzo       .def(
1755436c6c9cSStella Laurenzo           "get_dialect_descriptor",
1756436c6c9cSStella Laurenzo           [=](PyMlirContext &self, std::string &name) {
1757436c6c9cSStella Laurenzo             MlirDialect dialect = mlirContextGetOrLoadDialect(
1758436c6c9cSStella Laurenzo                 self.get(), {name.data(), name.size()});
1759436c6c9cSStella Laurenzo             if (mlirDialectIsNull(dialect)) {
1760436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
1761436c6c9cSStella Laurenzo                                Twine("Dialect '") + name + "' not found");
1762436c6c9cSStella Laurenzo             }
1763436c6c9cSStella Laurenzo             return PyDialectDescriptor(self.getRef(), dialect);
1764436c6c9cSStella Laurenzo           },
1765436c6c9cSStella Laurenzo           "Gets or loads a dialect by name, returning its descriptor object")
1766436c6c9cSStella Laurenzo       .def_property(
1767436c6c9cSStella Laurenzo           "allow_unregistered_dialects",
1768436c6c9cSStella Laurenzo           [](PyMlirContext &self) -> bool {
1769436c6c9cSStella Laurenzo             return mlirContextGetAllowUnregisteredDialects(self.get());
1770436c6c9cSStella Laurenzo           },
1771436c6c9cSStella Laurenzo           [](PyMlirContext &self, bool value) {
1772436c6c9cSStella Laurenzo             mlirContextSetAllowUnregisteredDialects(self.get(), value);
17739a9214faSStella Laurenzo           })
1774*caa159f0SNicolas Vasilache       .def("enable_multithreading",
1775*caa159f0SNicolas Vasilache            [](PyMlirContext &self, bool enable) {
1776*caa159f0SNicolas Vasilache              mlirContextEnableMultithreading(self.get(), enable);
1777*caa159f0SNicolas Vasilache            })
17789a9214faSStella Laurenzo       .def("is_registered_operation",
17799a9214faSStella Laurenzo            [](PyMlirContext &self, std::string &name) {
17809a9214faSStella Laurenzo              return mlirContextIsRegisteredOperation(
17819a9214faSStella Laurenzo                  self.get(), MlirStringRef{name.data(), name.size()});
1782436c6c9cSStella Laurenzo            });
1783436c6c9cSStella Laurenzo 
1784436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1785436c6c9cSStella Laurenzo   // Mapping of PyDialectDescriptor
1786436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1787436c6c9cSStella Laurenzo   py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
1788436c6c9cSStella Laurenzo       .def_property_readonly("namespace",
1789436c6c9cSStella Laurenzo                              [](PyDialectDescriptor &self) {
1790436c6c9cSStella Laurenzo                                MlirStringRef ns =
1791436c6c9cSStella Laurenzo                                    mlirDialectGetNamespace(self.get());
1792436c6c9cSStella Laurenzo                                return py::str(ns.data, ns.length);
1793436c6c9cSStella Laurenzo                              })
1794436c6c9cSStella Laurenzo       .def("__repr__", [](PyDialectDescriptor &self) {
1795436c6c9cSStella Laurenzo         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1796436c6c9cSStella Laurenzo         std::string repr("<DialectDescriptor ");
1797436c6c9cSStella Laurenzo         repr.append(ns.data, ns.length);
1798436c6c9cSStella Laurenzo         repr.append(">");
1799436c6c9cSStella Laurenzo         return repr;
1800436c6c9cSStella Laurenzo       });
1801436c6c9cSStella Laurenzo 
1802436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1803436c6c9cSStella Laurenzo   // Mapping of PyDialects
1804436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1805436c6c9cSStella Laurenzo   py::class_<PyDialects>(m, "Dialects")
1806436c6c9cSStella Laurenzo       .def("__getitem__",
1807436c6c9cSStella Laurenzo            [=](PyDialects &self, std::string keyName) {
1808436c6c9cSStella Laurenzo              MlirDialect dialect =
1809436c6c9cSStella Laurenzo                  self.getDialectForKey(keyName, /*attrError=*/false);
1810436c6c9cSStella Laurenzo              py::object descriptor =
1811436c6c9cSStella Laurenzo                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1812436c6c9cSStella Laurenzo              return createCustomDialectWrapper(keyName, std::move(descriptor));
1813436c6c9cSStella Laurenzo            })
1814436c6c9cSStella Laurenzo       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1815436c6c9cSStella Laurenzo         MlirDialect dialect =
1816436c6c9cSStella Laurenzo             self.getDialectForKey(attrName, /*attrError=*/true);
1817436c6c9cSStella Laurenzo         py::object descriptor =
1818436c6c9cSStella Laurenzo             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1819436c6c9cSStella Laurenzo         return createCustomDialectWrapper(attrName, std::move(descriptor));
1820436c6c9cSStella Laurenzo       });
1821436c6c9cSStella Laurenzo 
1822436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1823436c6c9cSStella Laurenzo   // Mapping of PyDialect
1824436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1825436c6c9cSStella Laurenzo   py::class_<PyDialect>(m, "Dialect")
1826436c6c9cSStella Laurenzo       .def(py::init<py::object>(), "descriptor")
1827436c6c9cSStella Laurenzo       .def_property_readonly(
1828436c6c9cSStella Laurenzo           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1829436c6c9cSStella Laurenzo       .def("__repr__", [](py::object self) {
1830436c6c9cSStella Laurenzo         auto clazz = self.attr("__class__");
1831436c6c9cSStella Laurenzo         return py::str("<Dialect ") +
1832436c6c9cSStella Laurenzo                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1833436c6c9cSStella Laurenzo                clazz.attr("__module__") + py::str(".") +
1834436c6c9cSStella Laurenzo                clazz.attr("__name__") + py::str(")>");
1835436c6c9cSStella Laurenzo       });
1836436c6c9cSStella Laurenzo 
1837436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1838436c6c9cSStella Laurenzo   // Mapping of Location
1839436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1840436c6c9cSStella Laurenzo   py::class_<PyLocation>(m, "Location")
1841436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
1842436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
1843436c6c9cSStella Laurenzo       .def("__enter__", &PyLocation::contextEnter)
1844436c6c9cSStella Laurenzo       .def("__exit__", &PyLocation::contextExit)
1845436c6c9cSStella Laurenzo       .def("__eq__",
1846436c6c9cSStella Laurenzo            [](PyLocation &self, PyLocation &other) -> bool {
1847436c6c9cSStella Laurenzo              return mlirLocationEqual(self, other);
1848436c6c9cSStella Laurenzo            })
1849436c6c9cSStella Laurenzo       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
1850436c6c9cSStella Laurenzo       .def_property_readonly_static(
1851436c6c9cSStella Laurenzo           "current",
1852436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
1853436c6c9cSStella Laurenzo             auto *loc = PyThreadContextEntry::getDefaultLocation();
1854436c6c9cSStella Laurenzo             if (!loc)
1855436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current Location");
1856436c6c9cSStella Laurenzo             return loc;
1857436c6c9cSStella Laurenzo           },
1858436c6c9cSStella Laurenzo           "Gets the Location bound to the current thread or raises ValueError")
1859436c6c9cSStella Laurenzo       .def_static(
1860436c6c9cSStella Laurenzo           "unknown",
1861436c6c9cSStella Laurenzo           [](DefaultingPyMlirContext context) {
1862436c6c9cSStella Laurenzo             return PyLocation(context->getRef(),
1863436c6c9cSStella Laurenzo                               mlirLocationUnknownGet(context->get()));
1864436c6c9cSStella Laurenzo           },
1865436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
1866436c6c9cSStella Laurenzo           "Gets a Location representing an unknown location")
1867436c6c9cSStella Laurenzo       .def_static(
1868436c6c9cSStella Laurenzo           "file",
1869436c6c9cSStella Laurenzo           [](std::string filename, int line, int col,
1870436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
1871436c6c9cSStella Laurenzo             return PyLocation(
1872436c6c9cSStella Laurenzo                 context->getRef(),
1873436c6c9cSStella Laurenzo                 mlirLocationFileLineColGet(
1874436c6c9cSStella Laurenzo                     context->get(), toMlirStringRef(filename), line, col));
1875436c6c9cSStella Laurenzo           },
1876436c6c9cSStella Laurenzo           py::arg("filename"), py::arg("line"), py::arg("col"),
1877436c6c9cSStella Laurenzo           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
1878436c6c9cSStella Laurenzo       .def_property_readonly(
1879436c6c9cSStella Laurenzo           "context",
1880436c6c9cSStella Laurenzo           [](PyLocation &self) { return self.getContext().getObject(); },
1881436c6c9cSStella Laurenzo           "Context that owns the Location")
1882436c6c9cSStella Laurenzo       .def("__repr__", [](PyLocation &self) {
1883436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
1884436c6c9cSStella Laurenzo         mlirLocationPrint(self, printAccum.getCallback(),
1885436c6c9cSStella Laurenzo                           printAccum.getUserData());
1886436c6c9cSStella Laurenzo         return printAccum.join();
1887436c6c9cSStella Laurenzo       });
1888436c6c9cSStella Laurenzo 
1889436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1890436c6c9cSStella Laurenzo   // Mapping of Module
1891436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1892436c6c9cSStella Laurenzo   py::class_<PyModule>(m, "Module")
1893436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
1894436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
1895436c6c9cSStella Laurenzo       .def_static(
1896436c6c9cSStella Laurenzo           "parse",
1897436c6c9cSStella Laurenzo           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
1898436c6c9cSStella Laurenzo             MlirModule module = mlirModuleCreateParse(
1899436c6c9cSStella Laurenzo                 context->get(), toMlirStringRef(moduleAsm));
1900436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
1901436c6c9cSStella Laurenzo             // in C API.
1902436c6c9cSStella Laurenzo             if (mlirModuleIsNull(module)) {
1903436c6c9cSStella Laurenzo               throw SetPyError(
1904436c6c9cSStella Laurenzo                   PyExc_ValueError,
1905436c6c9cSStella Laurenzo                   "Unable to parse module assembly (see diagnostics)");
1906436c6c9cSStella Laurenzo             }
1907436c6c9cSStella Laurenzo             return PyModule::forModule(module).releaseObject();
1908436c6c9cSStella Laurenzo           },
1909436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
1910436c6c9cSStella Laurenzo           kModuleParseDocstring)
1911436c6c9cSStella Laurenzo       .def_static(
1912436c6c9cSStella Laurenzo           "create",
1913436c6c9cSStella Laurenzo           [](DefaultingPyLocation loc) {
1914436c6c9cSStella Laurenzo             MlirModule module = mlirModuleCreateEmpty(loc);
1915436c6c9cSStella Laurenzo             return PyModule::forModule(module).releaseObject();
1916436c6c9cSStella Laurenzo           },
1917436c6c9cSStella Laurenzo           py::arg("loc") = py::none(), "Creates an empty module")
1918436c6c9cSStella Laurenzo       .def_property_readonly(
1919436c6c9cSStella Laurenzo           "context",
1920436c6c9cSStella Laurenzo           [](PyModule &self) { return self.getContext().getObject(); },
1921436c6c9cSStella Laurenzo           "Context that created the Module")
1922436c6c9cSStella Laurenzo       .def_property_readonly(
1923436c6c9cSStella Laurenzo           "operation",
1924436c6c9cSStella Laurenzo           [](PyModule &self) {
1925436c6c9cSStella Laurenzo             return PyOperation::forOperation(self.getContext(),
1926436c6c9cSStella Laurenzo                                              mlirModuleGetOperation(self.get()),
1927436c6c9cSStella Laurenzo                                              self.getRef().releaseObject())
1928436c6c9cSStella Laurenzo                 .releaseObject();
1929436c6c9cSStella Laurenzo           },
1930436c6c9cSStella Laurenzo           "Accesses the module as an operation")
1931436c6c9cSStella Laurenzo       .def_property_readonly(
1932436c6c9cSStella Laurenzo           "body",
1933436c6c9cSStella Laurenzo           [](PyModule &self) {
1934436c6c9cSStella Laurenzo             PyOperationRef module_op = PyOperation::forOperation(
1935436c6c9cSStella Laurenzo                 self.getContext(), mlirModuleGetOperation(self.get()),
1936436c6c9cSStella Laurenzo                 self.getRef().releaseObject());
1937436c6c9cSStella Laurenzo             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
1938436c6c9cSStella Laurenzo             return returnBlock;
1939436c6c9cSStella Laurenzo           },
1940436c6c9cSStella Laurenzo           "Return the block for this module")
1941436c6c9cSStella Laurenzo       .def(
1942436c6c9cSStella Laurenzo           "dump",
1943436c6c9cSStella Laurenzo           [](PyModule &self) {
1944436c6c9cSStella Laurenzo             mlirOperationDump(mlirModuleGetOperation(self.get()));
1945436c6c9cSStella Laurenzo           },
1946436c6c9cSStella Laurenzo           kDumpDocstring)
1947436c6c9cSStella Laurenzo       .def(
1948436c6c9cSStella Laurenzo           "__str__",
1949436c6c9cSStella Laurenzo           [](PyModule &self) {
1950436c6c9cSStella Laurenzo             MlirOperation operation = mlirModuleGetOperation(self.get());
1951436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
1952436c6c9cSStella Laurenzo             mlirOperationPrint(operation, printAccum.getCallback(),
1953436c6c9cSStella Laurenzo                                printAccum.getUserData());
1954436c6c9cSStella Laurenzo             return printAccum.join();
1955436c6c9cSStella Laurenzo           },
1956436c6c9cSStella Laurenzo           kOperationStrDunderDocstring);
1957436c6c9cSStella Laurenzo 
1958436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1959436c6c9cSStella Laurenzo   // Mapping of Operation.
1960436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1961436c6c9cSStella Laurenzo   py::class_<PyOperationBase>(m, "_OperationBase")
1962436c6c9cSStella Laurenzo       .def("__eq__",
1963436c6c9cSStella Laurenzo            [](PyOperationBase &self, PyOperationBase &other) {
1964436c6c9cSStella Laurenzo              return &self.getOperation() == &other.getOperation();
1965436c6c9cSStella Laurenzo            })
1966436c6c9cSStella Laurenzo       .def("__eq__",
1967436c6c9cSStella Laurenzo            [](PyOperationBase &self, py::object other) { return false; })
1968436c6c9cSStella Laurenzo       .def_property_readonly("attributes",
1969436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
1970436c6c9cSStella Laurenzo                                return PyOpAttributeMap(
1971436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
1972436c6c9cSStella Laurenzo                              })
1973436c6c9cSStella Laurenzo       .def_property_readonly("operands",
1974436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
1975436c6c9cSStella Laurenzo                                return PyOpOperandList(
1976436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
1977436c6c9cSStella Laurenzo                              })
1978436c6c9cSStella Laurenzo       .def_property_readonly("regions",
1979436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
1980436c6c9cSStella Laurenzo                                return PyRegionList(
1981436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
1982436c6c9cSStella Laurenzo                              })
1983436c6c9cSStella Laurenzo       .def_property_readonly(
1984436c6c9cSStella Laurenzo           "results",
1985436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
1986436c6c9cSStella Laurenzo             return PyOpResultList(self.getOperation().getRef());
1987436c6c9cSStella Laurenzo           },
1988436c6c9cSStella Laurenzo           "Returns the list of Operation results.")
1989436c6c9cSStella Laurenzo       .def_property_readonly(
1990436c6c9cSStella Laurenzo           "result",
1991436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
1992436c6c9cSStella Laurenzo             auto &operation = self.getOperation();
1993436c6c9cSStella Laurenzo             auto numResults = mlirOperationGetNumResults(operation);
1994436c6c9cSStella Laurenzo             if (numResults != 1) {
1995436c6c9cSStella Laurenzo               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1996436c6c9cSStella Laurenzo               throw SetPyError(
1997436c6c9cSStella Laurenzo                   PyExc_ValueError,
1998436c6c9cSStella Laurenzo                   Twine("Cannot call .result on operation ") +
1999436c6c9cSStella Laurenzo                       StringRef(name.data, name.length) + " which has " +
2000436c6c9cSStella Laurenzo                       Twine(numResults) +
2001436c6c9cSStella Laurenzo                       " results (it is only valid for operations with a "
2002436c6c9cSStella Laurenzo                       "single result)");
2003436c6c9cSStella Laurenzo             }
2004436c6c9cSStella Laurenzo             return PyOpResult(operation.getRef(),
2005436c6c9cSStella Laurenzo                               mlirOperationGetResult(operation, 0));
2006436c6c9cSStella Laurenzo           },
2007436c6c9cSStella Laurenzo           "Shortcut to get an op result if it has only one (throws an error "
2008436c6c9cSStella Laurenzo           "otherwise).")
2009436c6c9cSStella Laurenzo       .def("__iter__",
2010436c6c9cSStella Laurenzo            [](PyOperationBase &self) {
2011436c6c9cSStella Laurenzo              return PyRegionIterator(self.getOperation().getRef());
2012436c6c9cSStella Laurenzo            })
2013436c6c9cSStella Laurenzo       .def(
2014436c6c9cSStella Laurenzo           "__str__",
2015436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2016436c6c9cSStella Laurenzo             return self.getAsm(/*binary=*/false,
2017436c6c9cSStella Laurenzo                                /*largeElementsLimit=*/llvm::None,
2018436c6c9cSStella Laurenzo                                /*enableDebugInfo=*/false,
2019436c6c9cSStella Laurenzo                                /*prettyDebugInfo=*/false,
2020436c6c9cSStella Laurenzo                                /*printGenericOpForm=*/false,
2021436c6c9cSStella Laurenzo                                /*useLocalScope=*/false);
2022436c6c9cSStella Laurenzo           },
2023436c6c9cSStella Laurenzo           "Returns the assembly form of the operation.")
2024436c6c9cSStella Laurenzo       .def("print", &PyOperationBase::print,
2025436c6c9cSStella Laurenzo            // Careful: Lots of arguments must match up with print method.
2026436c6c9cSStella Laurenzo            py::arg("file") = py::none(), py::arg("binary") = false,
2027436c6c9cSStella Laurenzo            py::arg("large_elements_limit") = py::none(),
2028436c6c9cSStella Laurenzo            py::arg("enable_debug_info") = false,
2029436c6c9cSStella Laurenzo            py::arg("pretty_debug_info") = false,
2030436c6c9cSStella Laurenzo            py::arg("print_generic_op_form") = false,
2031436c6c9cSStella Laurenzo            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2032436c6c9cSStella Laurenzo       .def("get_asm", &PyOperationBase::getAsm,
2033436c6c9cSStella Laurenzo            // Careful: Lots of arguments must match up with get_asm method.
2034436c6c9cSStella Laurenzo            py::arg("binary") = false,
2035436c6c9cSStella Laurenzo            py::arg("large_elements_limit") = py::none(),
2036436c6c9cSStella Laurenzo            py::arg("enable_debug_info") = false,
2037436c6c9cSStella Laurenzo            py::arg("pretty_debug_info") = false,
2038436c6c9cSStella Laurenzo            py::arg("print_generic_op_form") = false,
2039436c6c9cSStella Laurenzo            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2040436c6c9cSStella Laurenzo       .def(
2041436c6c9cSStella Laurenzo           "verify",
2042436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2043436c6c9cSStella Laurenzo             return mlirOperationVerify(self.getOperation());
2044436c6c9cSStella Laurenzo           },
2045436c6c9cSStella Laurenzo           "Verify the operation and return true if it passes, false if it "
2046436c6c9cSStella Laurenzo           "fails.");
2047436c6c9cSStella Laurenzo 
2048436c6c9cSStella Laurenzo   py::class_<PyOperation, PyOperationBase>(m, "Operation")
2049436c6c9cSStella Laurenzo       .def_static("create", &PyOperation::create, py::arg("name"),
2050436c6c9cSStella Laurenzo                   py::arg("results") = py::none(),
2051436c6c9cSStella Laurenzo                   py::arg("operands") = py::none(),
2052436c6c9cSStella Laurenzo                   py::arg("attributes") = py::none(),
2053436c6c9cSStella Laurenzo                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2054436c6c9cSStella Laurenzo                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2055436c6c9cSStella Laurenzo                   kOperationCreateDocstring)
20560126e906SJohn Demme       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
20570126e906SJohn Demme                              &PyOperation::getCapsule)
20580126e906SJohn Demme       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2059436c6c9cSStella Laurenzo       .def_property_readonly("name",
2060436c6c9cSStella Laurenzo                              [](PyOperation &self) {
2061436c6c9cSStella Laurenzo                                MlirOperation operation = self.get();
2062436c6c9cSStella Laurenzo                                MlirStringRef name = mlirIdentifierStr(
2063436c6c9cSStella Laurenzo                                    mlirOperationGetName(operation));
2064436c6c9cSStella Laurenzo                                return py::str(name.data, name.length);
2065436c6c9cSStella Laurenzo                              })
2066436c6c9cSStella Laurenzo       .def_property_readonly(
2067436c6c9cSStella Laurenzo           "context",
2068436c6c9cSStella Laurenzo           [](PyOperation &self) { return self.getContext().getObject(); },
2069436c6c9cSStella Laurenzo           "Context that owns the Operation")
2070436c6c9cSStella Laurenzo       .def_property_readonly("opview", &PyOperation::createOpView);
2071436c6c9cSStella Laurenzo 
2072436c6c9cSStella Laurenzo   auto opViewClass =
2073436c6c9cSStella Laurenzo       py::class_<PyOpView, PyOperationBase>(m, "OpView")
2074436c6c9cSStella Laurenzo           .def(py::init<py::object>())
2075436c6c9cSStella Laurenzo           .def_property_readonly("operation", &PyOpView::getOperationObject)
2076436c6c9cSStella Laurenzo           .def_property_readonly(
2077436c6c9cSStella Laurenzo               "context",
2078436c6c9cSStella Laurenzo               [](PyOpView &self) {
2079436c6c9cSStella Laurenzo                 return self.getOperation().getContext().getObject();
2080436c6c9cSStella Laurenzo               },
2081436c6c9cSStella Laurenzo               "Context that owns the Operation")
2082436c6c9cSStella Laurenzo           .def("__str__", [](PyOpView &self) {
2083436c6c9cSStella Laurenzo             return py::str(self.getOperationObject());
2084436c6c9cSStella Laurenzo           });
2085436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2086436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2087436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2088436c6c9cSStella Laurenzo   opViewClass.attr("build_generic") = classmethod(
2089436c6c9cSStella Laurenzo       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2090436c6c9cSStella Laurenzo       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2091436c6c9cSStella Laurenzo       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2092436c6c9cSStella Laurenzo       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2093436c6c9cSStella Laurenzo       "Builds a specific, generated OpView based on class level attributes.");
2094436c6c9cSStella Laurenzo 
2095436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2096436c6c9cSStella Laurenzo   // Mapping of PyRegion.
2097436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2098436c6c9cSStella Laurenzo   py::class_<PyRegion>(m, "Region")
2099436c6c9cSStella Laurenzo       .def_property_readonly(
2100436c6c9cSStella Laurenzo           "blocks",
2101436c6c9cSStella Laurenzo           [](PyRegion &self) {
2102436c6c9cSStella Laurenzo             return PyBlockList(self.getParentOperation(), self.get());
2103436c6c9cSStella Laurenzo           },
2104436c6c9cSStella Laurenzo           "Returns a forward-optimized sequence of blocks.")
2105436c6c9cSStella Laurenzo       .def(
2106436c6c9cSStella Laurenzo           "__iter__",
2107436c6c9cSStella Laurenzo           [](PyRegion &self) {
2108436c6c9cSStella Laurenzo             self.checkValid();
2109436c6c9cSStella Laurenzo             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2110436c6c9cSStella Laurenzo             return PyBlockIterator(self.getParentOperation(), firstBlock);
2111436c6c9cSStella Laurenzo           },
2112436c6c9cSStella Laurenzo           "Iterates over blocks in the region.")
2113436c6c9cSStella Laurenzo       .def("__eq__",
2114436c6c9cSStella Laurenzo            [](PyRegion &self, PyRegion &other) {
2115436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2116436c6c9cSStella Laurenzo            })
2117436c6c9cSStella Laurenzo       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2118436c6c9cSStella Laurenzo 
2119436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2120436c6c9cSStella Laurenzo   // Mapping of PyBlock.
2121436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2122436c6c9cSStella Laurenzo   py::class_<PyBlock>(m, "Block")
2123436c6c9cSStella Laurenzo       .def_property_readonly(
2124436c6c9cSStella Laurenzo           "arguments",
2125436c6c9cSStella Laurenzo           [](PyBlock &self) {
2126436c6c9cSStella Laurenzo             return PyBlockArgumentList(self.getParentOperation(), self.get());
2127436c6c9cSStella Laurenzo           },
2128436c6c9cSStella Laurenzo           "Returns a list of block arguments.")
2129436c6c9cSStella Laurenzo       .def_property_readonly(
2130436c6c9cSStella Laurenzo           "operations",
2131436c6c9cSStella Laurenzo           [](PyBlock &self) {
2132436c6c9cSStella Laurenzo             return PyOperationList(self.getParentOperation(), self.get());
2133436c6c9cSStella Laurenzo           },
2134436c6c9cSStella Laurenzo           "Returns a forward-optimized sequence of operations.")
2135436c6c9cSStella Laurenzo       .def(
2136436c6c9cSStella Laurenzo           "__iter__",
2137436c6c9cSStella Laurenzo           [](PyBlock &self) {
2138436c6c9cSStella Laurenzo             self.checkValid();
2139436c6c9cSStella Laurenzo             MlirOperation firstOperation =
2140436c6c9cSStella Laurenzo                 mlirBlockGetFirstOperation(self.get());
2141436c6c9cSStella Laurenzo             return PyOperationIterator(self.getParentOperation(),
2142436c6c9cSStella Laurenzo                                        firstOperation);
2143436c6c9cSStella Laurenzo           },
2144436c6c9cSStella Laurenzo           "Iterates over operations in the block.")
2145436c6c9cSStella Laurenzo       .def("__eq__",
2146436c6c9cSStella Laurenzo            [](PyBlock &self, PyBlock &other) {
2147436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2148436c6c9cSStella Laurenzo            })
2149436c6c9cSStella Laurenzo       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2150436c6c9cSStella Laurenzo       .def(
2151436c6c9cSStella Laurenzo           "__str__",
2152436c6c9cSStella Laurenzo           [](PyBlock &self) {
2153436c6c9cSStella Laurenzo             self.checkValid();
2154436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2155436c6c9cSStella Laurenzo             mlirBlockPrint(self.get(), printAccum.getCallback(),
2156436c6c9cSStella Laurenzo                            printAccum.getUserData());
2157436c6c9cSStella Laurenzo             return printAccum.join();
2158436c6c9cSStella Laurenzo           },
2159436c6c9cSStella Laurenzo           "Returns the assembly form of the block.");
2160436c6c9cSStella Laurenzo 
2161436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2162436c6c9cSStella Laurenzo   // Mapping of PyInsertionPoint.
2163436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2164436c6c9cSStella Laurenzo 
2165436c6c9cSStella Laurenzo   py::class_<PyInsertionPoint>(m, "InsertionPoint")
2166436c6c9cSStella Laurenzo       .def(py::init<PyBlock &>(), py::arg("block"),
2167436c6c9cSStella Laurenzo            "Inserts after the last operation but still inside the block.")
2168436c6c9cSStella Laurenzo       .def("__enter__", &PyInsertionPoint::contextEnter)
2169436c6c9cSStella Laurenzo       .def("__exit__", &PyInsertionPoint::contextExit)
2170436c6c9cSStella Laurenzo       .def_property_readonly_static(
2171436c6c9cSStella Laurenzo           "current",
2172436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
2173436c6c9cSStella Laurenzo             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2174436c6c9cSStella Laurenzo             if (!ip)
2175436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2176436c6c9cSStella Laurenzo             return ip;
2177436c6c9cSStella Laurenzo           },
2178436c6c9cSStella Laurenzo           "Gets the InsertionPoint bound to the current thread or raises "
2179436c6c9cSStella Laurenzo           "ValueError if none has been set")
2180436c6c9cSStella Laurenzo       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2181436c6c9cSStella Laurenzo            "Inserts before a referenced operation.")
2182436c6c9cSStella Laurenzo       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2183436c6c9cSStella Laurenzo                   py::arg("block"), "Inserts at the beginning of the block.")
2184436c6c9cSStella Laurenzo       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2185436c6c9cSStella Laurenzo                   py::arg("block"), "Inserts before the block terminator.")
2186436c6c9cSStella Laurenzo       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2187436c6c9cSStella Laurenzo            "Inserts an operation.");
2188436c6c9cSStella Laurenzo 
2189436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2190436c6c9cSStella Laurenzo   // Mapping of PyAttribute.
2191436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2192436c6c9cSStella Laurenzo   py::class_<PyAttribute>(m, "Attribute")
2193436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2194436c6c9cSStella Laurenzo                              &PyAttribute::getCapsule)
2195436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2196436c6c9cSStella Laurenzo       .def_static(
2197436c6c9cSStella Laurenzo           "parse",
2198436c6c9cSStella Laurenzo           [](std::string attrSpec, DefaultingPyMlirContext context) {
2199436c6c9cSStella Laurenzo             MlirAttribute type = mlirAttributeParseGet(
2200436c6c9cSStella Laurenzo                 context->get(), toMlirStringRef(attrSpec));
2201436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
2202436c6c9cSStella Laurenzo             // in C API.
2203436c6c9cSStella Laurenzo             if (mlirAttributeIsNull(type)) {
2204436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
2205436c6c9cSStella Laurenzo                                Twine("Unable to parse attribute: '") +
2206436c6c9cSStella Laurenzo                                    attrSpec + "'");
2207436c6c9cSStella Laurenzo             }
2208436c6c9cSStella Laurenzo             return PyAttribute(context->getRef(), type);
2209436c6c9cSStella Laurenzo           },
2210436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
2211436c6c9cSStella Laurenzo           "Parses an attribute from an assembly form")
2212436c6c9cSStella Laurenzo       .def_property_readonly(
2213436c6c9cSStella Laurenzo           "context",
2214436c6c9cSStella Laurenzo           [](PyAttribute &self) { return self.getContext().getObject(); },
2215436c6c9cSStella Laurenzo           "Context that owns the Attribute")
2216436c6c9cSStella Laurenzo       .def_property_readonly("type",
2217436c6c9cSStella Laurenzo                              [](PyAttribute &self) {
2218436c6c9cSStella Laurenzo                                return PyType(self.getContext()->getRef(),
2219436c6c9cSStella Laurenzo                                              mlirAttributeGetType(self));
2220436c6c9cSStella Laurenzo                              })
2221436c6c9cSStella Laurenzo       .def(
2222436c6c9cSStella Laurenzo           "get_named",
2223436c6c9cSStella Laurenzo           [](PyAttribute &self, std::string name) {
2224436c6c9cSStella Laurenzo             return PyNamedAttribute(self, std::move(name));
2225436c6c9cSStella Laurenzo           },
2226436c6c9cSStella Laurenzo           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2227436c6c9cSStella Laurenzo       .def("__eq__",
2228436c6c9cSStella Laurenzo            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2229436c6c9cSStella Laurenzo       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2230436c6c9cSStella Laurenzo       .def(
2231436c6c9cSStella Laurenzo           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2232436c6c9cSStella Laurenzo           kDumpDocstring)
2233436c6c9cSStella Laurenzo       .def(
2234436c6c9cSStella Laurenzo           "__str__",
2235436c6c9cSStella Laurenzo           [](PyAttribute &self) {
2236436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2237436c6c9cSStella Laurenzo             mlirAttributePrint(self, printAccum.getCallback(),
2238436c6c9cSStella Laurenzo                                printAccum.getUserData());
2239436c6c9cSStella Laurenzo             return printAccum.join();
2240436c6c9cSStella Laurenzo           },
2241436c6c9cSStella Laurenzo           "Returns the assembly form of the Attribute.")
2242436c6c9cSStella Laurenzo       .def("__repr__", [](PyAttribute &self) {
2243436c6c9cSStella Laurenzo         // Generally, assembly formats are not printed for __repr__ because
2244436c6c9cSStella Laurenzo         // this can cause exceptionally long debug output and exceptions.
2245436c6c9cSStella Laurenzo         // However, attribute values are generally considered useful and are
2246436c6c9cSStella Laurenzo         // printed. This may need to be re-evaluated if debug dumps end up
2247436c6c9cSStella Laurenzo         // being excessive.
2248436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
2249436c6c9cSStella Laurenzo         printAccum.parts.append("Attribute(");
2250436c6c9cSStella Laurenzo         mlirAttributePrint(self, printAccum.getCallback(),
2251436c6c9cSStella Laurenzo                            printAccum.getUserData());
2252436c6c9cSStella Laurenzo         printAccum.parts.append(")");
2253436c6c9cSStella Laurenzo         return printAccum.join();
2254436c6c9cSStella Laurenzo       });
2255436c6c9cSStella Laurenzo 
2256436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2257436c6c9cSStella Laurenzo   // Mapping of PyNamedAttribute
2258436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2259436c6c9cSStella Laurenzo   py::class_<PyNamedAttribute>(m, "NamedAttribute")
2260436c6c9cSStella Laurenzo       .def("__repr__",
2261436c6c9cSStella Laurenzo            [](PyNamedAttribute &self) {
2262436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
2263436c6c9cSStella Laurenzo              printAccum.parts.append("NamedAttribute(");
2264436c6c9cSStella Laurenzo              printAccum.parts.append(
2265436c6c9cSStella Laurenzo                  mlirIdentifierStr(self.namedAttr.name).data);
2266436c6c9cSStella Laurenzo              printAccum.parts.append("=");
2267436c6c9cSStella Laurenzo              mlirAttributePrint(self.namedAttr.attribute,
2268436c6c9cSStella Laurenzo                                 printAccum.getCallback(),
2269436c6c9cSStella Laurenzo                                 printAccum.getUserData());
2270436c6c9cSStella Laurenzo              printAccum.parts.append(")");
2271436c6c9cSStella Laurenzo              return printAccum.join();
2272436c6c9cSStella Laurenzo            })
2273436c6c9cSStella Laurenzo       .def_property_readonly(
2274436c6c9cSStella Laurenzo           "name",
2275436c6c9cSStella Laurenzo           [](PyNamedAttribute &self) {
2276436c6c9cSStella Laurenzo             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2277436c6c9cSStella Laurenzo                            mlirIdentifierStr(self.namedAttr.name).length);
2278436c6c9cSStella Laurenzo           },
2279436c6c9cSStella Laurenzo           "The name of the NamedAttribute binding")
2280436c6c9cSStella Laurenzo       .def_property_readonly(
2281436c6c9cSStella Laurenzo           "attr",
2282436c6c9cSStella Laurenzo           [](PyNamedAttribute &self) {
2283436c6c9cSStella Laurenzo             // TODO: When named attribute is removed/refactored, also remove
2284436c6c9cSStella Laurenzo             // this constructor (it does an inefficient table lookup).
2285436c6c9cSStella Laurenzo             auto contextRef = PyMlirContext::forContext(
2286436c6c9cSStella Laurenzo                 mlirAttributeGetContext(self.namedAttr.attribute));
2287436c6c9cSStella Laurenzo             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2288436c6c9cSStella Laurenzo           },
2289436c6c9cSStella Laurenzo           py::keep_alive<0, 1>(),
2290436c6c9cSStella Laurenzo           "The underlying generic attribute of the NamedAttribute binding");
2291436c6c9cSStella Laurenzo 
2292436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2293436c6c9cSStella Laurenzo   // Mapping of PyType.
2294436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2295436c6c9cSStella Laurenzo   py::class_<PyType>(m, "Type")
2296436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2297436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2298436c6c9cSStella Laurenzo       .def_static(
2299436c6c9cSStella Laurenzo           "parse",
2300436c6c9cSStella Laurenzo           [](std::string typeSpec, DefaultingPyMlirContext context) {
2301436c6c9cSStella Laurenzo             MlirType type =
2302436c6c9cSStella Laurenzo                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2303436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
2304436c6c9cSStella Laurenzo             // in C API.
2305436c6c9cSStella Laurenzo             if (mlirTypeIsNull(type)) {
2306436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
2307436c6c9cSStella Laurenzo                                Twine("Unable to parse type: '") + typeSpec +
2308436c6c9cSStella Laurenzo                                    "'");
2309436c6c9cSStella Laurenzo             }
2310436c6c9cSStella Laurenzo             return PyType(context->getRef(), type);
2311436c6c9cSStella Laurenzo           },
2312436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
2313436c6c9cSStella Laurenzo           kContextParseTypeDocstring)
2314436c6c9cSStella Laurenzo       .def_property_readonly(
2315436c6c9cSStella Laurenzo           "context", [](PyType &self) { return self.getContext().getObject(); },
2316436c6c9cSStella Laurenzo           "Context that owns the Type")
2317436c6c9cSStella Laurenzo       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2318436c6c9cSStella Laurenzo       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2319436c6c9cSStella Laurenzo       .def(
2320436c6c9cSStella Laurenzo           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2321436c6c9cSStella Laurenzo       .def(
2322436c6c9cSStella Laurenzo           "__str__",
2323436c6c9cSStella Laurenzo           [](PyType &self) {
2324436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2325436c6c9cSStella Laurenzo             mlirTypePrint(self, printAccum.getCallback(),
2326436c6c9cSStella Laurenzo                           printAccum.getUserData());
2327436c6c9cSStella Laurenzo             return printAccum.join();
2328436c6c9cSStella Laurenzo           },
2329436c6c9cSStella Laurenzo           "Returns the assembly form of the type.")
2330436c6c9cSStella Laurenzo       .def("__repr__", [](PyType &self) {
2331436c6c9cSStella Laurenzo         // Generally, assembly formats are not printed for __repr__ because
2332436c6c9cSStella Laurenzo         // this can cause exceptionally long debug output and exceptions.
2333436c6c9cSStella Laurenzo         // However, types are an exception as they typically have compact
2334436c6c9cSStella Laurenzo         // assembly forms and printing them is useful.
2335436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
2336436c6c9cSStella Laurenzo         printAccum.parts.append("Type(");
2337436c6c9cSStella Laurenzo         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2338436c6c9cSStella Laurenzo         printAccum.parts.append(")");
2339436c6c9cSStella Laurenzo         return printAccum.join();
2340436c6c9cSStella Laurenzo       });
2341436c6c9cSStella Laurenzo 
2342436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2343436c6c9cSStella Laurenzo   // Mapping of Value.
2344436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2345436c6c9cSStella Laurenzo   py::class_<PyValue>(m, "Value")
2346436c6c9cSStella Laurenzo       .def_property_readonly(
2347436c6c9cSStella Laurenzo           "context",
2348436c6c9cSStella Laurenzo           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2349436c6c9cSStella Laurenzo           "Context in which the value lives.")
2350436c6c9cSStella Laurenzo       .def(
2351436c6c9cSStella Laurenzo           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2352436c6c9cSStella Laurenzo           kDumpDocstring)
2353436c6c9cSStella Laurenzo       .def("__eq__",
2354436c6c9cSStella Laurenzo            [](PyValue &self, PyValue &other) {
2355436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2356436c6c9cSStella Laurenzo            })
2357436c6c9cSStella Laurenzo       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2358436c6c9cSStella Laurenzo       .def(
2359436c6c9cSStella Laurenzo           "__str__",
2360436c6c9cSStella Laurenzo           [](PyValue &self) {
2361436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2362436c6c9cSStella Laurenzo             printAccum.parts.append("Value(");
2363436c6c9cSStella Laurenzo             mlirValuePrint(self.get(), printAccum.getCallback(),
2364436c6c9cSStella Laurenzo                            printAccum.getUserData());
2365436c6c9cSStella Laurenzo             printAccum.parts.append(")");
2366436c6c9cSStella Laurenzo             return printAccum.join();
2367436c6c9cSStella Laurenzo           },
2368436c6c9cSStella Laurenzo           kValueDunderStrDocstring)
2369436c6c9cSStella Laurenzo       .def_property_readonly("type", [](PyValue &self) {
2370436c6c9cSStella Laurenzo         return PyType(self.getParentOperation()->getContext(),
2371436c6c9cSStella Laurenzo                       mlirValueGetType(self.get()));
2372436c6c9cSStella Laurenzo       });
2373436c6c9cSStella Laurenzo   PyBlockArgument::bind(m);
2374436c6c9cSStella Laurenzo   PyOpResult::bind(m);
2375436c6c9cSStella Laurenzo 
2376436c6c9cSStella Laurenzo   // Container bindings.
2377436c6c9cSStella Laurenzo   PyBlockArgumentList::bind(m);
2378436c6c9cSStella Laurenzo   PyBlockIterator::bind(m);
2379436c6c9cSStella Laurenzo   PyBlockList::bind(m);
2380436c6c9cSStella Laurenzo   PyOperationIterator::bind(m);
2381436c6c9cSStella Laurenzo   PyOperationList::bind(m);
2382436c6c9cSStella Laurenzo   PyOpAttributeMap::bind(m);
2383436c6c9cSStella Laurenzo   PyOpOperandList::bind(m);
2384436c6c9cSStella Laurenzo   PyOpResultList::bind(m);
2385436c6c9cSStella Laurenzo   PyRegionIterator::bind(m);
2386436c6c9cSStella Laurenzo   PyRegionList::bind(m);
2387436c6c9cSStella Laurenzo }
2388