1436c6c9cSStella Laurenzo //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9436c6c9cSStella Laurenzo #include "IRModule.h"
10436c6c9cSStella Laurenzo 
11436c6c9cSStella Laurenzo #include "Globals.h"
12436c6c9cSStella Laurenzo #include "PybindUtils.h"
13436c6c9cSStella Laurenzo 
14436c6c9cSStella Laurenzo #include "mlir-c/Bindings/Python/Interop.h"
15436c6c9cSStella Laurenzo #include "mlir-c/BuiltinAttributes.h"
16436c6c9cSStella Laurenzo #include "mlir-c/BuiltinTypes.h"
174acd8457SAlex Zinenko #include "mlir-c/Debug.h"
183f3d1c90SMike Urbach #include "mlir-c/IR.h"
19436c6c9cSStella Laurenzo #include "mlir-c/Registration.h"
20436c6c9cSStella Laurenzo #include "llvm/ADT/SmallVector.h"
21436c6c9cSStella Laurenzo #include <pybind11/stl.h>
22436c6c9cSStella Laurenzo 
23436c6c9cSStella Laurenzo namespace py = pybind11;
24436c6c9cSStella Laurenzo using namespace mlir;
25436c6c9cSStella Laurenzo using namespace mlir::python;
26436c6c9cSStella Laurenzo 
27436c6c9cSStella Laurenzo using llvm::SmallVector;
28436c6c9cSStella Laurenzo using llvm::StringRef;
29436c6c9cSStella Laurenzo using llvm::Twine;
30436c6c9cSStella Laurenzo 
31436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
32436c6c9cSStella Laurenzo // Docstrings (trivial, non-duplicated docstrings are included inline).
33436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
34436c6c9cSStella Laurenzo 
// Keep argument names mentioned in these texts in sync with the keyword
// argument names of the bindings they document.

static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column.)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file-like object.

Args:
  file: The file-like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: If set, elides elements attributes that have more
    than this number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
129436c6c9cSStella Laurenzo 
130436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
131436c6c9cSStella Laurenzo // Utilities.
132436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
133436c6c9cSStella Laurenzo 
1344acd8457SAlex Zinenko /// Helper for creating an @classmethod.
135436c6c9cSStella Laurenzo template <class Func, typename... Args>
136436c6c9cSStella Laurenzo py::object classmethod(Func f, Args... args) {
137436c6c9cSStella Laurenzo   py::object cf = py::cpp_function(f, args...);
138436c6c9cSStella Laurenzo   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
139436c6c9cSStella Laurenzo }
140436c6c9cSStella Laurenzo 
141436c6c9cSStella Laurenzo static py::object
142436c6c9cSStella Laurenzo createCustomDialectWrapper(const std::string &dialectNamespace,
143436c6c9cSStella Laurenzo                            py::object dialectDescriptor) {
144436c6c9cSStella Laurenzo   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
145436c6c9cSStella Laurenzo   if (!dialectClass) {
146436c6c9cSStella Laurenzo     // Use the base class.
147436c6c9cSStella Laurenzo     return py::cast(PyDialect(std::move(dialectDescriptor)));
148436c6c9cSStella Laurenzo   }
149436c6c9cSStella Laurenzo 
150436c6c9cSStella Laurenzo   // Create the custom implementation.
151436c6c9cSStella Laurenzo   return (*dialectClass)(std::move(dialectDescriptor));
152436c6c9cSStella Laurenzo }
153436c6c9cSStella Laurenzo 
154436c6c9cSStella Laurenzo static MlirStringRef toMlirStringRef(const std::string &s) {
155436c6c9cSStella Laurenzo   return mlirStringRefCreate(s.data(), s.size());
156436c6c9cSStella Laurenzo }
157436c6c9cSStella Laurenzo 
1584acd8457SAlex Zinenko /// Wrapper for the global LLVM debugging flag.
1594acd8457SAlex Zinenko struct PyGlobalDebugFlag {
1604acd8457SAlex Zinenko   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
1614acd8457SAlex Zinenko 
1624acd8457SAlex Zinenko   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
1634acd8457SAlex Zinenko 
1644acd8457SAlex Zinenko   static void bind(py::module &m) {
1654acd8457SAlex Zinenko     // Debug flags.
166f05ff4f7SStella Laurenzo     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
1674acd8457SAlex Zinenko         .def_property_static("flag", &PyGlobalDebugFlag::get,
1684acd8457SAlex Zinenko                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
1694acd8457SAlex Zinenko   }
1704acd8457SAlex Zinenko };
1714acd8457SAlex Zinenko 
172436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
173436c6c9cSStella Laurenzo // Collections.
174436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
175436c6c9cSStella Laurenzo 
176436c6c9cSStella Laurenzo namespace {
177436c6c9cSStella Laurenzo 
178436c6c9cSStella Laurenzo class PyRegionIterator {
179436c6c9cSStella Laurenzo public:
180436c6c9cSStella Laurenzo   PyRegionIterator(PyOperationRef operation)
181436c6c9cSStella Laurenzo       : operation(std::move(operation)) {}
182436c6c9cSStella Laurenzo 
183436c6c9cSStella Laurenzo   PyRegionIterator &dunderIter() { return *this; }
184436c6c9cSStella Laurenzo 
185436c6c9cSStella Laurenzo   PyRegion dunderNext() {
186436c6c9cSStella Laurenzo     operation->checkValid();
187436c6c9cSStella Laurenzo     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
188436c6c9cSStella Laurenzo       throw py::stop_iteration();
189436c6c9cSStella Laurenzo     }
190436c6c9cSStella Laurenzo     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
191436c6c9cSStella Laurenzo     return PyRegion(operation, region);
192436c6c9cSStella Laurenzo   }
193436c6c9cSStella Laurenzo 
194436c6c9cSStella Laurenzo   static void bind(py::module &m) {
195f05ff4f7SStella Laurenzo     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
196436c6c9cSStella Laurenzo         .def("__iter__", &PyRegionIterator::dunderIter)
197436c6c9cSStella Laurenzo         .def("__next__", &PyRegionIterator::dunderNext);
198436c6c9cSStella Laurenzo   }
199436c6c9cSStella Laurenzo 
200436c6c9cSStella Laurenzo private:
201436c6c9cSStella Laurenzo   PyOperationRef operation;
202436c6c9cSStella Laurenzo   int nextIndex = 0;
203436c6c9cSStella Laurenzo };
204436c6c9cSStella Laurenzo 
205436c6c9cSStella Laurenzo /// Regions of an op are fixed length and indexed numerically so are represented
206436c6c9cSStella Laurenzo /// with a sequence-like container.
207436c6c9cSStella Laurenzo class PyRegionList {
208436c6c9cSStella Laurenzo public:
209436c6c9cSStella Laurenzo   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
210436c6c9cSStella Laurenzo 
211436c6c9cSStella Laurenzo   intptr_t dunderLen() {
212436c6c9cSStella Laurenzo     operation->checkValid();
213436c6c9cSStella Laurenzo     return mlirOperationGetNumRegions(operation->get());
214436c6c9cSStella Laurenzo   }
215436c6c9cSStella Laurenzo 
216436c6c9cSStella Laurenzo   PyRegion dunderGetItem(intptr_t index) {
217436c6c9cSStella Laurenzo     // dunderLen checks validity.
218436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
219436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
220436c6c9cSStella Laurenzo                        "attempt to access out of bounds region");
221436c6c9cSStella Laurenzo     }
222436c6c9cSStella Laurenzo     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
223436c6c9cSStella Laurenzo     return PyRegion(operation, region);
224436c6c9cSStella Laurenzo   }
225436c6c9cSStella Laurenzo 
226436c6c9cSStella Laurenzo   static void bind(py::module &m) {
227f05ff4f7SStella Laurenzo     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
228436c6c9cSStella Laurenzo         .def("__len__", &PyRegionList::dunderLen)
229436c6c9cSStella Laurenzo         .def("__getitem__", &PyRegionList::dunderGetItem);
230436c6c9cSStella Laurenzo   }
231436c6c9cSStella Laurenzo 
232436c6c9cSStella Laurenzo private:
233436c6c9cSStella Laurenzo   PyOperationRef operation;
234436c6c9cSStella Laurenzo };
235436c6c9cSStella Laurenzo 
236436c6c9cSStella Laurenzo class PyBlockIterator {
237436c6c9cSStella Laurenzo public:
238436c6c9cSStella Laurenzo   PyBlockIterator(PyOperationRef operation, MlirBlock next)
239436c6c9cSStella Laurenzo       : operation(std::move(operation)), next(next) {}
240436c6c9cSStella Laurenzo 
241436c6c9cSStella Laurenzo   PyBlockIterator &dunderIter() { return *this; }
242436c6c9cSStella Laurenzo 
243436c6c9cSStella Laurenzo   PyBlock dunderNext() {
244436c6c9cSStella Laurenzo     operation->checkValid();
245436c6c9cSStella Laurenzo     if (mlirBlockIsNull(next)) {
246436c6c9cSStella Laurenzo       throw py::stop_iteration();
247436c6c9cSStella Laurenzo     }
248436c6c9cSStella Laurenzo 
249436c6c9cSStella Laurenzo     PyBlock returnBlock(operation, next);
250436c6c9cSStella Laurenzo     next = mlirBlockGetNextInRegion(next);
251436c6c9cSStella Laurenzo     return returnBlock;
252436c6c9cSStella Laurenzo   }
253436c6c9cSStella Laurenzo 
254436c6c9cSStella Laurenzo   static void bind(py::module &m) {
255f05ff4f7SStella Laurenzo     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
256436c6c9cSStella Laurenzo         .def("__iter__", &PyBlockIterator::dunderIter)
257436c6c9cSStella Laurenzo         .def("__next__", &PyBlockIterator::dunderNext);
258436c6c9cSStella Laurenzo   }
259436c6c9cSStella Laurenzo 
260436c6c9cSStella Laurenzo private:
261436c6c9cSStella Laurenzo   PyOperationRef operation;
262436c6c9cSStella Laurenzo   MlirBlock next;
263436c6c9cSStella Laurenzo };
264436c6c9cSStella Laurenzo 
265436c6c9cSStella Laurenzo /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
266436c6c9cSStella Laurenzo /// we present them as a more full-featured list-like container but optimize
267436c6c9cSStella Laurenzo /// it for forward iteration. Blocks are always owned by a region.
268436c6c9cSStella Laurenzo class PyBlockList {
269436c6c9cSStella Laurenzo public:
270436c6c9cSStella Laurenzo   PyBlockList(PyOperationRef operation, MlirRegion region)
271436c6c9cSStella Laurenzo       : operation(std::move(operation)), region(region) {}
272436c6c9cSStella Laurenzo 
273436c6c9cSStella Laurenzo   PyBlockIterator dunderIter() {
274436c6c9cSStella Laurenzo     operation->checkValid();
275436c6c9cSStella Laurenzo     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
276436c6c9cSStella Laurenzo   }
277436c6c9cSStella Laurenzo 
278436c6c9cSStella Laurenzo   intptr_t dunderLen() {
279436c6c9cSStella Laurenzo     operation->checkValid();
280436c6c9cSStella Laurenzo     intptr_t count = 0;
281436c6c9cSStella Laurenzo     MlirBlock block = mlirRegionGetFirstBlock(region);
282436c6c9cSStella Laurenzo     while (!mlirBlockIsNull(block)) {
283436c6c9cSStella Laurenzo       count += 1;
284436c6c9cSStella Laurenzo       block = mlirBlockGetNextInRegion(block);
285436c6c9cSStella Laurenzo     }
286436c6c9cSStella Laurenzo     return count;
287436c6c9cSStella Laurenzo   }
288436c6c9cSStella Laurenzo 
289436c6c9cSStella Laurenzo   PyBlock dunderGetItem(intptr_t index) {
290436c6c9cSStella Laurenzo     operation->checkValid();
291436c6c9cSStella Laurenzo     if (index < 0) {
292436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
293436c6c9cSStella Laurenzo                        "attempt to access out of bounds block");
294436c6c9cSStella Laurenzo     }
295436c6c9cSStella Laurenzo     MlirBlock block = mlirRegionGetFirstBlock(region);
296436c6c9cSStella Laurenzo     while (!mlirBlockIsNull(block)) {
297436c6c9cSStella Laurenzo       if (index == 0) {
298436c6c9cSStella Laurenzo         return PyBlock(operation, block);
299436c6c9cSStella Laurenzo       }
300436c6c9cSStella Laurenzo       block = mlirBlockGetNextInRegion(block);
301436c6c9cSStella Laurenzo       index -= 1;
302436c6c9cSStella Laurenzo     }
303436c6c9cSStella Laurenzo     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
304436c6c9cSStella Laurenzo   }
305436c6c9cSStella Laurenzo 
306436c6c9cSStella Laurenzo   PyBlock appendBlock(py::args pyArgTypes) {
307436c6c9cSStella Laurenzo     operation->checkValid();
308436c6c9cSStella Laurenzo     llvm::SmallVector<MlirType, 4> argTypes;
309436c6c9cSStella Laurenzo     argTypes.reserve(pyArgTypes.size());
310436c6c9cSStella Laurenzo     for (auto &pyArg : pyArgTypes) {
311436c6c9cSStella Laurenzo       argTypes.push_back(pyArg.cast<PyType &>());
312436c6c9cSStella Laurenzo     }
313436c6c9cSStella Laurenzo 
314436c6c9cSStella Laurenzo     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
315436c6c9cSStella Laurenzo     mlirRegionAppendOwnedBlock(region, block);
316436c6c9cSStella Laurenzo     return PyBlock(operation, block);
317436c6c9cSStella Laurenzo   }
318436c6c9cSStella Laurenzo 
319436c6c9cSStella Laurenzo   static void bind(py::module &m) {
320f05ff4f7SStella Laurenzo     py::class_<PyBlockList>(m, "BlockList", py::module_local())
321436c6c9cSStella Laurenzo         .def("__getitem__", &PyBlockList::dunderGetItem)
322436c6c9cSStella Laurenzo         .def("__iter__", &PyBlockList::dunderIter)
323436c6c9cSStella Laurenzo         .def("__len__", &PyBlockList::dunderLen)
324436c6c9cSStella Laurenzo         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
325436c6c9cSStella Laurenzo   }
326436c6c9cSStella Laurenzo 
327436c6c9cSStella Laurenzo private:
328436c6c9cSStella Laurenzo   PyOperationRef operation;
329436c6c9cSStella Laurenzo   MlirRegion region;
330436c6c9cSStella Laurenzo };
331436c6c9cSStella Laurenzo 
332436c6c9cSStella Laurenzo class PyOperationIterator {
333436c6c9cSStella Laurenzo public:
334436c6c9cSStella Laurenzo   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
335436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), next(next) {}
336436c6c9cSStella Laurenzo 
337436c6c9cSStella Laurenzo   PyOperationIterator &dunderIter() { return *this; }
338436c6c9cSStella Laurenzo 
339436c6c9cSStella Laurenzo   py::object dunderNext() {
340436c6c9cSStella Laurenzo     parentOperation->checkValid();
341436c6c9cSStella Laurenzo     if (mlirOperationIsNull(next)) {
342436c6c9cSStella Laurenzo       throw py::stop_iteration();
343436c6c9cSStella Laurenzo     }
344436c6c9cSStella Laurenzo 
345436c6c9cSStella Laurenzo     PyOperationRef returnOperation =
346436c6c9cSStella Laurenzo         PyOperation::forOperation(parentOperation->getContext(), next);
347436c6c9cSStella Laurenzo     next = mlirOperationGetNextInBlock(next);
348436c6c9cSStella Laurenzo     return returnOperation->createOpView();
349436c6c9cSStella Laurenzo   }
350436c6c9cSStella Laurenzo 
351436c6c9cSStella Laurenzo   static void bind(py::module &m) {
352f05ff4f7SStella Laurenzo     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
353436c6c9cSStella Laurenzo         .def("__iter__", &PyOperationIterator::dunderIter)
354436c6c9cSStella Laurenzo         .def("__next__", &PyOperationIterator::dunderNext);
355436c6c9cSStella Laurenzo   }
356436c6c9cSStella Laurenzo 
357436c6c9cSStella Laurenzo private:
358436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
359436c6c9cSStella Laurenzo   MlirOperation next;
360436c6c9cSStella Laurenzo };
361436c6c9cSStella Laurenzo 
362436c6c9cSStella Laurenzo /// Operations are exposed by the C-API as a forward-only linked list. In
363436c6c9cSStella Laurenzo /// Python, we present them as a more full-featured list-like container but
364436c6c9cSStella Laurenzo /// optimize it for forward iteration. Iterable operations are always owned
365436c6c9cSStella Laurenzo /// by a block.
366436c6c9cSStella Laurenzo class PyOperationList {
367436c6c9cSStella Laurenzo public:
368436c6c9cSStella Laurenzo   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
369436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), block(block) {}
370436c6c9cSStella Laurenzo 
371436c6c9cSStella Laurenzo   PyOperationIterator dunderIter() {
372436c6c9cSStella Laurenzo     parentOperation->checkValid();
373436c6c9cSStella Laurenzo     return PyOperationIterator(parentOperation,
374436c6c9cSStella Laurenzo                                mlirBlockGetFirstOperation(block));
375436c6c9cSStella Laurenzo   }
376436c6c9cSStella Laurenzo 
377436c6c9cSStella Laurenzo   intptr_t dunderLen() {
378436c6c9cSStella Laurenzo     parentOperation->checkValid();
379436c6c9cSStella Laurenzo     intptr_t count = 0;
380436c6c9cSStella Laurenzo     MlirOperation childOp = mlirBlockGetFirstOperation(block);
381436c6c9cSStella Laurenzo     while (!mlirOperationIsNull(childOp)) {
382436c6c9cSStella Laurenzo       count += 1;
383436c6c9cSStella Laurenzo       childOp = mlirOperationGetNextInBlock(childOp);
384436c6c9cSStella Laurenzo     }
385436c6c9cSStella Laurenzo     return count;
386436c6c9cSStella Laurenzo   }
387436c6c9cSStella Laurenzo 
388436c6c9cSStella Laurenzo   py::object dunderGetItem(intptr_t index) {
389436c6c9cSStella Laurenzo     parentOperation->checkValid();
390436c6c9cSStella Laurenzo     if (index < 0) {
391436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
392436c6c9cSStella Laurenzo                        "attempt to access out of bounds operation");
393436c6c9cSStella Laurenzo     }
394436c6c9cSStella Laurenzo     MlirOperation childOp = mlirBlockGetFirstOperation(block);
395436c6c9cSStella Laurenzo     while (!mlirOperationIsNull(childOp)) {
396436c6c9cSStella Laurenzo       if (index == 0) {
397436c6c9cSStella Laurenzo         return PyOperation::forOperation(parentOperation->getContext(), childOp)
398436c6c9cSStella Laurenzo             ->createOpView();
399436c6c9cSStella Laurenzo       }
400436c6c9cSStella Laurenzo       childOp = mlirOperationGetNextInBlock(childOp);
401436c6c9cSStella Laurenzo       index -= 1;
402436c6c9cSStella Laurenzo     }
403436c6c9cSStella Laurenzo     throw SetPyError(PyExc_IndexError,
404436c6c9cSStella Laurenzo                      "attempt to access out of bounds operation");
405436c6c9cSStella Laurenzo   }
406436c6c9cSStella Laurenzo 
407436c6c9cSStella Laurenzo   static void bind(py::module &m) {
408f05ff4f7SStella Laurenzo     py::class_<PyOperationList>(m, "OperationList", py::module_local())
409436c6c9cSStella Laurenzo         .def("__getitem__", &PyOperationList::dunderGetItem)
410436c6c9cSStella Laurenzo         .def("__iter__", &PyOperationList::dunderIter)
411436c6c9cSStella Laurenzo         .def("__len__", &PyOperationList::dunderLen);
412436c6c9cSStella Laurenzo   }
413436c6c9cSStella Laurenzo 
414436c6c9cSStella Laurenzo private:
415436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
416436c6c9cSStella Laurenzo   MlirBlock block;
417436c6c9cSStella Laurenzo };
418436c6c9cSStella Laurenzo 
419436c6c9cSStella Laurenzo } // namespace
420436c6c9cSStella Laurenzo 
421436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
422436c6c9cSStella Laurenzo // PyMlirContext
423436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
424436c6c9cSStella Laurenzo 
425436c6c9cSStella Laurenzo PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
426436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
427436c6c9cSStella Laurenzo   auto &liveContexts = getLiveContexts();
428436c6c9cSStella Laurenzo   liveContexts[context.ptr] = this;
429436c6c9cSStella Laurenzo }
430436c6c9cSStella Laurenzo 
431436c6c9cSStella Laurenzo PyMlirContext::~PyMlirContext() {
432436c6c9cSStella Laurenzo   // Note that the only public way to construct an instance is via the
433436c6c9cSStella Laurenzo   // forContext method, which always puts the associated handle into
434436c6c9cSStella Laurenzo   // liveContexts.
435436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
436436c6c9cSStella Laurenzo   getLiveContexts().erase(context.ptr);
437436c6c9cSStella Laurenzo   mlirContextDestroy(context);
438436c6c9cSStella Laurenzo }
439436c6c9cSStella Laurenzo 
440436c6c9cSStella Laurenzo py::object PyMlirContext::getCapsule() {
441436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
442436c6c9cSStella Laurenzo }
443436c6c9cSStella Laurenzo 
444436c6c9cSStella Laurenzo py::object PyMlirContext::createFromCapsule(py::object capsule) {
445436c6c9cSStella Laurenzo   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
446436c6c9cSStella Laurenzo   if (mlirContextIsNull(rawContext))
447436c6c9cSStella Laurenzo     throw py::error_already_set();
448436c6c9cSStella Laurenzo   return forContext(rawContext).releaseObject();
449436c6c9cSStella Laurenzo }
450436c6c9cSStella Laurenzo 
451436c6c9cSStella Laurenzo PyMlirContext *PyMlirContext::createNewContextForInit() {
452436c6c9cSStella Laurenzo   MlirContext context = mlirContextCreate();
453436c6c9cSStella Laurenzo   mlirRegisterAllDialects(context);
454436c6c9cSStella Laurenzo   return new PyMlirContext(context);
455436c6c9cSStella Laurenzo }
456436c6c9cSStella Laurenzo 
457436c6c9cSStella Laurenzo PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
458436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
459436c6c9cSStella Laurenzo   auto &liveContexts = getLiveContexts();
460436c6c9cSStella Laurenzo   auto it = liveContexts.find(context.ptr);
461436c6c9cSStella Laurenzo   if (it == liveContexts.end()) {
462436c6c9cSStella Laurenzo     // Create.
463436c6c9cSStella Laurenzo     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
464436c6c9cSStella Laurenzo     py::object pyRef = py::cast(unownedContextWrapper);
465436c6c9cSStella Laurenzo     assert(pyRef && "cast to py::object failed");
466436c6c9cSStella Laurenzo     liveContexts[context.ptr] = unownedContextWrapper;
467436c6c9cSStella Laurenzo     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
468436c6c9cSStella Laurenzo   }
469436c6c9cSStella Laurenzo   // Use existing.
470436c6c9cSStella Laurenzo   py::object pyRef = py::cast(it->second);
471436c6c9cSStella Laurenzo   return PyMlirContextRef(it->second, std::move(pyRef));
472436c6c9cSStella Laurenzo }
473436c6c9cSStella Laurenzo 
474436c6c9cSStella Laurenzo PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
475436c6c9cSStella Laurenzo   static LiveContextMap liveContexts;
476436c6c9cSStella Laurenzo   return liveContexts;
477436c6c9cSStella Laurenzo }
478436c6c9cSStella Laurenzo 
479436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
480436c6c9cSStella Laurenzo 
481436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
482436c6c9cSStella Laurenzo 
483436c6c9cSStella Laurenzo size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
484436c6c9cSStella Laurenzo 
485436c6c9cSStella Laurenzo pybind11::object PyMlirContext::contextEnter() {
486436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushContext(*this);
487436c6c9cSStella Laurenzo }
488436c6c9cSStella Laurenzo 
489436c6c9cSStella Laurenzo void PyMlirContext::contextExit(pybind11::object excType,
490436c6c9cSStella Laurenzo                                 pybind11::object excVal,
491436c6c9cSStella Laurenzo                                 pybind11::object excTb) {
492436c6c9cSStella Laurenzo   PyThreadContextEntry::popContext(*this);
493436c6c9cSStella Laurenzo }
494436c6c9cSStella Laurenzo 
495436c6c9cSStella Laurenzo PyMlirContext &DefaultingPyMlirContext::resolve() {
496436c6c9cSStella Laurenzo   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
497436c6c9cSStella Laurenzo   if (!context) {
498436c6c9cSStella Laurenzo     throw SetPyError(
499436c6c9cSStella Laurenzo         PyExc_RuntimeError,
500436c6c9cSStella Laurenzo         "An MLIR function requires a Context but none was provided in the call "
501436c6c9cSStella Laurenzo         "or from the surrounding environment. Either pass to the function with "
502436c6c9cSStella Laurenzo         "a 'context=' argument or establish a default using 'with Context():'");
503436c6c9cSStella Laurenzo   }
504436c6c9cSStella Laurenzo   return *context;
505436c6c9cSStella Laurenzo }
506436c6c9cSStella Laurenzo 
507436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
508436c6c9cSStella Laurenzo // PyThreadContextEntry management
509436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
510436c6c9cSStella Laurenzo 
511436c6c9cSStella Laurenzo std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
512436c6c9cSStella Laurenzo   static thread_local std::vector<PyThreadContextEntry> stack;
513436c6c9cSStella Laurenzo   return stack;
514436c6c9cSStella Laurenzo }
515436c6c9cSStella Laurenzo 
516436c6c9cSStella Laurenzo PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
517436c6c9cSStella Laurenzo   auto &stack = getStack();
518436c6c9cSStella Laurenzo   if (stack.empty())
519436c6c9cSStella Laurenzo     return nullptr;
520436c6c9cSStella Laurenzo   return &stack.back();
521436c6c9cSStella Laurenzo }
522436c6c9cSStella Laurenzo 
523436c6c9cSStella Laurenzo void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
524436c6c9cSStella Laurenzo                                 py::object insertionPoint,
525436c6c9cSStella Laurenzo                                 py::object location) {
526436c6c9cSStella Laurenzo   auto &stack = getStack();
527436c6c9cSStella Laurenzo   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
528436c6c9cSStella Laurenzo                      std::move(location));
529436c6c9cSStella Laurenzo   // If the new stack has more than one entry and the context of the new top
530436c6c9cSStella Laurenzo   // entry matches the previous, copy the insertionPoint and location from the
531436c6c9cSStella Laurenzo   // previous entry if missing from the new top entry.
532436c6c9cSStella Laurenzo   if (stack.size() > 1) {
533436c6c9cSStella Laurenzo     auto &prev = *(stack.rbegin() + 1);
534436c6c9cSStella Laurenzo     auto &current = stack.back();
535436c6c9cSStella Laurenzo     if (current.context.is(prev.context)) {
536436c6c9cSStella Laurenzo       // Default non-context objects from the previous entry.
537436c6c9cSStella Laurenzo       if (!current.insertionPoint)
538436c6c9cSStella Laurenzo         current.insertionPoint = prev.insertionPoint;
539436c6c9cSStella Laurenzo       if (!current.location)
540436c6c9cSStella Laurenzo         current.location = prev.location;
541436c6c9cSStella Laurenzo     }
542436c6c9cSStella Laurenzo   }
543436c6c9cSStella Laurenzo }
544436c6c9cSStella Laurenzo 
545436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getContext() {
546436c6c9cSStella Laurenzo   if (!context)
547436c6c9cSStella Laurenzo     return nullptr;
548436c6c9cSStella Laurenzo   return py::cast<PyMlirContext *>(context);
549436c6c9cSStella Laurenzo }
550436c6c9cSStella Laurenzo 
551436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
552436c6c9cSStella Laurenzo   if (!insertionPoint)
553436c6c9cSStella Laurenzo     return nullptr;
554436c6c9cSStella Laurenzo   return py::cast<PyInsertionPoint *>(insertionPoint);
555436c6c9cSStella Laurenzo }
556436c6c9cSStella Laurenzo 
557436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getLocation() {
558436c6c9cSStella Laurenzo   if (!location)
559436c6c9cSStella Laurenzo     return nullptr;
560436c6c9cSStella Laurenzo   return py::cast<PyLocation *>(location);
561436c6c9cSStella Laurenzo }
562436c6c9cSStella Laurenzo 
563436c6c9cSStella Laurenzo PyMlirContext *PyThreadContextEntry::getDefaultContext() {
564436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
565436c6c9cSStella Laurenzo   return tos ? tos->getContext() : nullptr;
566436c6c9cSStella Laurenzo }
567436c6c9cSStella Laurenzo 
568436c6c9cSStella Laurenzo PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
569436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
570436c6c9cSStella Laurenzo   return tos ? tos->getInsertionPoint() : nullptr;
571436c6c9cSStella Laurenzo }
572436c6c9cSStella Laurenzo 
573436c6c9cSStella Laurenzo PyLocation *PyThreadContextEntry::getDefaultLocation() {
574436c6c9cSStella Laurenzo   auto *tos = getTopOfStack();
575436c6c9cSStella Laurenzo   return tos ? tos->getLocation() : nullptr;
576436c6c9cSStella Laurenzo }
577436c6c9cSStella Laurenzo 
578436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
579436c6c9cSStella Laurenzo   py::object contextObj = py::cast(context);
580436c6c9cSStella Laurenzo   push(FrameKind::Context, /*context=*/contextObj,
581436c6c9cSStella Laurenzo        /*insertionPoint=*/py::object(),
582436c6c9cSStella Laurenzo        /*location=*/py::object());
583436c6c9cSStella Laurenzo   return contextObj;
584436c6c9cSStella Laurenzo }
585436c6c9cSStella Laurenzo 
586436c6c9cSStella Laurenzo void PyThreadContextEntry::popContext(PyMlirContext &context) {
587436c6c9cSStella Laurenzo   auto &stack = getStack();
588436c6c9cSStella Laurenzo   if (stack.empty())
589436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
590436c6c9cSStella Laurenzo   auto &tos = stack.back();
591436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
592436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
593436c6c9cSStella Laurenzo   stack.pop_back();
594436c6c9cSStella Laurenzo }
595436c6c9cSStella Laurenzo 
596436c6c9cSStella Laurenzo py::object
597436c6c9cSStella Laurenzo PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
598436c6c9cSStella Laurenzo   py::object contextObj =
599436c6c9cSStella Laurenzo       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
600436c6c9cSStella Laurenzo   py::object insertionPointObj = py::cast(insertionPoint);
601436c6c9cSStella Laurenzo   push(FrameKind::InsertionPoint,
602436c6c9cSStella Laurenzo        /*context=*/contextObj,
603436c6c9cSStella Laurenzo        /*insertionPoint=*/insertionPointObj,
604436c6c9cSStella Laurenzo        /*location=*/py::object());
605436c6c9cSStella Laurenzo   return insertionPointObj;
606436c6c9cSStella Laurenzo }
607436c6c9cSStella Laurenzo 
608436c6c9cSStella Laurenzo void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
609436c6c9cSStella Laurenzo   auto &stack = getStack();
610436c6c9cSStella Laurenzo   if (stack.empty())
611436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError,
612436c6c9cSStella Laurenzo                      "Unbalanced InsertionPoint enter/exit");
613436c6c9cSStella Laurenzo   auto &tos = stack.back();
614436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::InsertionPoint &&
615436c6c9cSStella Laurenzo       tos.getInsertionPoint() != &insertionPoint)
616436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError,
617436c6c9cSStella Laurenzo                      "Unbalanced InsertionPoint enter/exit");
618436c6c9cSStella Laurenzo   stack.pop_back();
619436c6c9cSStella Laurenzo }
620436c6c9cSStella Laurenzo 
621436c6c9cSStella Laurenzo py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
622436c6c9cSStella Laurenzo   py::object contextObj = location.getContext().getObject();
623436c6c9cSStella Laurenzo   py::object locationObj = py::cast(location);
624436c6c9cSStella Laurenzo   push(FrameKind::Location, /*context=*/contextObj,
625436c6c9cSStella Laurenzo        /*insertionPoint=*/py::object(),
626436c6c9cSStella Laurenzo        /*location=*/locationObj);
627436c6c9cSStella Laurenzo   return locationObj;
628436c6c9cSStella Laurenzo }
629436c6c9cSStella Laurenzo 
630436c6c9cSStella Laurenzo void PyThreadContextEntry::popLocation(PyLocation &location) {
631436c6c9cSStella Laurenzo   auto &stack = getStack();
632436c6c9cSStella Laurenzo   if (stack.empty())
633436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
634436c6c9cSStella Laurenzo   auto &tos = stack.back();
635436c6c9cSStella Laurenzo   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
636436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
637436c6c9cSStella Laurenzo   stack.pop_back();
638436c6c9cSStella Laurenzo }
639436c6c9cSStella Laurenzo 
640436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
641436c6c9cSStella Laurenzo // PyDialect, PyDialectDescriptor, PyDialects
642436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
643436c6c9cSStella Laurenzo 
644436c6c9cSStella Laurenzo MlirDialect PyDialects::getDialectForKey(const std::string &key,
645436c6c9cSStella Laurenzo                                          bool attrError) {
646f8479d9dSRiver Riddle   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
647f8479d9dSRiver Riddle                                                     {key.data(), key.size()});
648436c6c9cSStella Laurenzo   if (mlirDialectIsNull(dialect)) {
649436c6c9cSStella Laurenzo     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
650436c6c9cSStella Laurenzo                      Twine("Dialect '") + key + "' not found");
651436c6c9cSStella Laurenzo   }
652436c6c9cSStella Laurenzo   return dialect;
653436c6c9cSStella Laurenzo }
654436c6c9cSStella Laurenzo 
655436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
656436c6c9cSStella Laurenzo // PyLocation
657436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
658436c6c9cSStella Laurenzo 
659436c6c9cSStella Laurenzo py::object PyLocation::getCapsule() {
660436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
661436c6c9cSStella Laurenzo }
662436c6c9cSStella Laurenzo 
663436c6c9cSStella Laurenzo PyLocation PyLocation::createFromCapsule(py::object capsule) {
664436c6c9cSStella Laurenzo   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
665436c6c9cSStella Laurenzo   if (mlirLocationIsNull(rawLoc))
666436c6c9cSStella Laurenzo     throw py::error_already_set();
667436c6c9cSStella Laurenzo   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
668436c6c9cSStella Laurenzo                     rawLoc);
669436c6c9cSStella Laurenzo }
670436c6c9cSStella Laurenzo 
671436c6c9cSStella Laurenzo py::object PyLocation::contextEnter() {
672436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushLocation(*this);
673436c6c9cSStella Laurenzo }
674436c6c9cSStella Laurenzo 
675436c6c9cSStella Laurenzo void PyLocation::contextExit(py::object excType, py::object excVal,
676436c6c9cSStella Laurenzo                              py::object excTb) {
677436c6c9cSStella Laurenzo   PyThreadContextEntry::popLocation(*this);
678436c6c9cSStella Laurenzo }
679436c6c9cSStella Laurenzo 
680436c6c9cSStella Laurenzo PyLocation &DefaultingPyLocation::resolve() {
681436c6c9cSStella Laurenzo   auto *location = PyThreadContextEntry::getDefaultLocation();
682436c6c9cSStella Laurenzo   if (!location) {
683436c6c9cSStella Laurenzo     throw SetPyError(
684436c6c9cSStella Laurenzo         PyExc_RuntimeError,
685436c6c9cSStella Laurenzo         "An MLIR function requires a Location but none was provided in the "
686436c6c9cSStella Laurenzo         "call or from the surrounding environment. Either pass to the function "
687436c6c9cSStella Laurenzo         "with a 'loc=' argument or establish a default using 'with loc:'");
688436c6c9cSStella Laurenzo   }
689436c6c9cSStella Laurenzo   return *location;
690436c6c9cSStella Laurenzo }
691436c6c9cSStella Laurenzo 
692436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
693436c6c9cSStella Laurenzo // PyModule
694436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
695436c6c9cSStella Laurenzo 
696436c6c9cSStella Laurenzo PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
697436c6c9cSStella Laurenzo     : BaseContextObject(std::move(contextRef)), module(module) {}
698436c6c9cSStella Laurenzo 
699436c6c9cSStella Laurenzo PyModule::~PyModule() {
700436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
701436c6c9cSStella Laurenzo   auto &liveModules = getContext()->liveModules;
702436c6c9cSStella Laurenzo   assert(liveModules.count(module.ptr) == 1 &&
703436c6c9cSStella Laurenzo          "destroying module not in live map");
704436c6c9cSStella Laurenzo   liveModules.erase(module.ptr);
705436c6c9cSStella Laurenzo   mlirModuleDestroy(module);
706436c6c9cSStella Laurenzo }
707436c6c9cSStella Laurenzo 
708436c6c9cSStella Laurenzo PyModuleRef PyModule::forModule(MlirModule module) {
709436c6c9cSStella Laurenzo   MlirContext context = mlirModuleGetContext(module);
710436c6c9cSStella Laurenzo   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
711436c6c9cSStella Laurenzo 
712436c6c9cSStella Laurenzo   py::gil_scoped_acquire acquire;
713436c6c9cSStella Laurenzo   auto &liveModules = contextRef->liveModules;
714436c6c9cSStella Laurenzo   auto it = liveModules.find(module.ptr);
715436c6c9cSStella Laurenzo   if (it == liveModules.end()) {
716436c6c9cSStella Laurenzo     // Create.
717436c6c9cSStella Laurenzo     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
718436c6c9cSStella Laurenzo     // Note that the default return value policy on cast is automatic_reference,
719436c6c9cSStella Laurenzo     // which does not take ownership (delete will not be called).
720436c6c9cSStella Laurenzo     // Just be explicit.
721436c6c9cSStella Laurenzo     py::object pyRef =
722436c6c9cSStella Laurenzo         py::cast(unownedModule, py::return_value_policy::take_ownership);
723436c6c9cSStella Laurenzo     unownedModule->handle = pyRef;
724436c6c9cSStella Laurenzo     liveModules[module.ptr] =
725436c6c9cSStella Laurenzo         std::make_pair(unownedModule->handle, unownedModule);
726436c6c9cSStella Laurenzo     return PyModuleRef(unownedModule, std::move(pyRef));
727436c6c9cSStella Laurenzo   }
728436c6c9cSStella Laurenzo   // Use existing.
729436c6c9cSStella Laurenzo   PyModule *existing = it->second.second;
730436c6c9cSStella Laurenzo   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
731436c6c9cSStella Laurenzo   return PyModuleRef(existing, std::move(pyRef));
732436c6c9cSStella Laurenzo }
733436c6c9cSStella Laurenzo 
734436c6c9cSStella Laurenzo py::object PyModule::createFromCapsule(py::object capsule) {
735436c6c9cSStella Laurenzo   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
736436c6c9cSStella Laurenzo   if (mlirModuleIsNull(rawModule))
737436c6c9cSStella Laurenzo     throw py::error_already_set();
738436c6c9cSStella Laurenzo   return forModule(rawModule).releaseObject();
739436c6c9cSStella Laurenzo }
740436c6c9cSStella Laurenzo 
741436c6c9cSStella Laurenzo py::object PyModule::getCapsule() {
742436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
743436c6c9cSStella Laurenzo }
744436c6c9cSStella Laurenzo 
745436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
746436c6c9cSStella Laurenzo // PyOperation
747436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
748436c6c9cSStella Laurenzo 
749436c6c9cSStella Laurenzo PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
750436c6c9cSStella Laurenzo     : BaseContextObject(std::move(contextRef)), operation(operation) {}
751436c6c9cSStella Laurenzo 
752436c6c9cSStella Laurenzo PyOperation::~PyOperation() {
75349745f87SMike Urbach   // If the operation has already been invalidated there is nothing to do.
75449745f87SMike Urbach   if (!valid)
75549745f87SMike Urbach     return;
756436c6c9cSStella Laurenzo   auto &liveOperations = getContext()->liveOperations;
757436c6c9cSStella Laurenzo   assert(liveOperations.count(operation.ptr) == 1 &&
758436c6c9cSStella Laurenzo          "destroying operation not in live map");
759436c6c9cSStella Laurenzo   liveOperations.erase(operation.ptr);
760436c6c9cSStella Laurenzo   if (!isAttached()) {
761436c6c9cSStella Laurenzo     mlirOperationDestroy(operation);
762436c6c9cSStella Laurenzo   }
763436c6c9cSStella Laurenzo }
764436c6c9cSStella Laurenzo 
765436c6c9cSStella Laurenzo PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
766436c6c9cSStella Laurenzo                                            MlirOperation operation,
767436c6c9cSStella Laurenzo                                            py::object parentKeepAlive) {
768436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
769436c6c9cSStella Laurenzo   // Create.
770436c6c9cSStella Laurenzo   PyOperation *unownedOperation =
771436c6c9cSStella Laurenzo       new PyOperation(std::move(contextRef), operation);
772436c6c9cSStella Laurenzo   // Note that the default return value policy on cast is automatic_reference,
773436c6c9cSStella Laurenzo   // which does not take ownership (delete will not be called).
774436c6c9cSStella Laurenzo   // Just be explicit.
775436c6c9cSStella Laurenzo   py::object pyRef =
776436c6c9cSStella Laurenzo       py::cast(unownedOperation, py::return_value_policy::take_ownership);
777436c6c9cSStella Laurenzo   unownedOperation->handle = pyRef;
778436c6c9cSStella Laurenzo   if (parentKeepAlive) {
779436c6c9cSStella Laurenzo     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
780436c6c9cSStella Laurenzo   }
781436c6c9cSStella Laurenzo   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
782436c6c9cSStella Laurenzo   return PyOperationRef(unownedOperation, std::move(pyRef));
783436c6c9cSStella Laurenzo }
784436c6c9cSStella Laurenzo 
785436c6c9cSStella Laurenzo PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
786436c6c9cSStella Laurenzo                                          MlirOperation operation,
787436c6c9cSStella Laurenzo                                          py::object parentKeepAlive) {
788436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
789436c6c9cSStella Laurenzo   auto it = liveOperations.find(operation.ptr);
790436c6c9cSStella Laurenzo   if (it == liveOperations.end()) {
791436c6c9cSStella Laurenzo     // Create.
792436c6c9cSStella Laurenzo     return createInstance(std::move(contextRef), operation,
793436c6c9cSStella Laurenzo                           std::move(parentKeepAlive));
794436c6c9cSStella Laurenzo   }
795436c6c9cSStella Laurenzo   // Use existing.
796436c6c9cSStella Laurenzo   PyOperation *existing = it->second.second;
797436c6c9cSStella Laurenzo   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
798436c6c9cSStella Laurenzo   return PyOperationRef(existing, std::move(pyRef));
799436c6c9cSStella Laurenzo }
800436c6c9cSStella Laurenzo 
801436c6c9cSStella Laurenzo PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
802436c6c9cSStella Laurenzo                                            MlirOperation operation,
803436c6c9cSStella Laurenzo                                            py::object parentKeepAlive) {
804436c6c9cSStella Laurenzo   auto &liveOperations = contextRef->liveOperations;
805436c6c9cSStella Laurenzo   assert(liveOperations.count(operation.ptr) == 0 &&
806436c6c9cSStella Laurenzo          "cannot create detached operation that already exists");
807436c6c9cSStella Laurenzo   (void)liveOperations;
808436c6c9cSStella Laurenzo 
809436c6c9cSStella Laurenzo   PyOperationRef created = createInstance(std::move(contextRef), operation,
810436c6c9cSStella Laurenzo                                           std::move(parentKeepAlive));
811436c6c9cSStella Laurenzo   created->attached = false;
812436c6c9cSStella Laurenzo   return created;
813436c6c9cSStella Laurenzo }
814436c6c9cSStella Laurenzo 
815436c6c9cSStella Laurenzo void PyOperation::checkValid() const {
816436c6c9cSStella Laurenzo   if (!valid) {
817436c6c9cSStella Laurenzo     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
818436c6c9cSStella Laurenzo   }
819436c6c9cSStella Laurenzo }
820436c6c9cSStella Laurenzo 
821436c6c9cSStella Laurenzo void PyOperationBase::print(py::object fileObject, bool binary,
822436c6c9cSStella Laurenzo                             llvm::Optional<int64_t> largeElementsLimit,
823436c6c9cSStella Laurenzo                             bool enableDebugInfo, bool prettyDebugInfo,
824436c6c9cSStella Laurenzo                             bool printGenericOpForm, bool useLocalScope) {
825436c6c9cSStella Laurenzo   PyOperation &operation = getOperation();
826436c6c9cSStella Laurenzo   operation.checkValid();
827436c6c9cSStella Laurenzo   if (fileObject.is_none())
828436c6c9cSStella Laurenzo     fileObject = py::module::import("sys").attr("stdout");
829436c6c9cSStella Laurenzo 
830436c6c9cSStella Laurenzo   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
831436c6c9cSStella Laurenzo     fileObject.attr("write")("// Verification failed, printing generic form\n");
832436c6c9cSStella Laurenzo     printGenericOpForm = true;
833436c6c9cSStella Laurenzo   }
834436c6c9cSStella Laurenzo 
835436c6c9cSStella Laurenzo   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
836436c6c9cSStella Laurenzo   if (largeElementsLimit)
837436c6c9cSStella Laurenzo     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
838436c6c9cSStella Laurenzo   if (enableDebugInfo)
839436c6c9cSStella Laurenzo     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
840436c6c9cSStella Laurenzo   if (printGenericOpForm)
841436c6c9cSStella Laurenzo     mlirOpPrintingFlagsPrintGenericOpForm(flags);
842436c6c9cSStella Laurenzo 
843436c6c9cSStella Laurenzo   PyFileAccumulator accum(fileObject, binary);
844436c6c9cSStella Laurenzo   py::gil_scoped_release();
845436c6c9cSStella Laurenzo   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
846436c6c9cSStella Laurenzo                               accum.getUserData());
847436c6c9cSStella Laurenzo   mlirOpPrintingFlagsDestroy(flags);
848436c6c9cSStella Laurenzo }
849436c6c9cSStella Laurenzo 
850436c6c9cSStella Laurenzo py::object PyOperationBase::getAsm(bool binary,
851436c6c9cSStella Laurenzo                                    llvm::Optional<int64_t> largeElementsLimit,
852436c6c9cSStella Laurenzo                                    bool enableDebugInfo, bool prettyDebugInfo,
853436c6c9cSStella Laurenzo                                    bool printGenericOpForm,
854436c6c9cSStella Laurenzo                                    bool useLocalScope) {
855436c6c9cSStella Laurenzo   py::object fileObject;
856436c6c9cSStella Laurenzo   if (binary) {
857436c6c9cSStella Laurenzo     fileObject = py::module::import("io").attr("BytesIO")();
858436c6c9cSStella Laurenzo   } else {
859436c6c9cSStella Laurenzo     fileObject = py::module::import("io").attr("StringIO")();
860436c6c9cSStella Laurenzo   }
861436c6c9cSStella Laurenzo   print(fileObject, /*binary=*/binary,
862436c6c9cSStella Laurenzo         /*largeElementsLimit=*/largeElementsLimit,
863436c6c9cSStella Laurenzo         /*enableDebugInfo=*/enableDebugInfo,
864436c6c9cSStella Laurenzo         /*prettyDebugInfo=*/prettyDebugInfo,
865436c6c9cSStella Laurenzo         /*printGenericOpForm=*/printGenericOpForm,
866436c6c9cSStella Laurenzo         /*useLocalScope=*/useLocalScope);
867436c6c9cSStella Laurenzo 
868436c6c9cSStella Laurenzo   return fileObject.attr("getvalue")();
869436c6c9cSStella Laurenzo }
870436c6c9cSStella Laurenzo 
8711689dadeSJohn Demme llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
87249745f87SMike Urbach   checkValid();
873436c6c9cSStella Laurenzo   if (!isAttached())
874436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
875436c6c9cSStella Laurenzo   MlirOperation operation = mlirOperationGetParentOperation(get());
876436c6c9cSStella Laurenzo   if (mlirOperationIsNull(operation))
8771689dadeSJohn Demme     return {};
878436c6c9cSStella Laurenzo   return PyOperation::forOperation(getContext(), operation);
879436c6c9cSStella Laurenzo }
880436c6c9cSStella Laurenzo 
881436c6c9cSStella Laurenzo PyBlock PyOperation::getBlock() {
88249745f87SMike Urbach   checkValid();
8831689dadeSJohn Demme   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
884436c6c9cSStella Laurenzo   MlirBlock block = mlirOperationGetBlock(get());
885436c6c9cSStella Laurenzo   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
8861689dadeSJohn Demme   assert(parentOperation && "Operation has no parent");
8871689dadeSJohn Demme   return PyBlock{std::move(*parentOperation), block};
888436c6c9cSStella Laurenzo }
889436c6c9cSStella Laurenzo 
8900126e906SJohn Demme py::object PyOperation::getCapsule() {
89149745f87SMike Urbach   checkValid();
8920126e906SJohn Demme   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
8930126e906SJohn Demme }
8940126e906SJohn Demme 
8950126e906SJohn Demme py::object PyOperation::createFromCapsule(py::object capsule) {
8960126e906SJohn Demme   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
8970126e906SJohn Demme   if (mlirOperationIsNull(rawOperation))
8980126e906SJohn Demme     throw py::error_already_set();
8990126e906SJohn Demme   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
9000126e906SJohn Demme   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
9010126e906SJohn Demme       .releaseObject();
9020126e906SJohn Demme }
9030126e906SJohn Demme 
904436c6c9cSStella Laurenzo py::object PyOperation::create(
905436c6c9cSStella Laurenzo     std::string name, llvm::Optional<std::vector<PyType *>> results,
906436c6c9cSStella Laurenzo     llvm::Optional<std::vector<PyValue *>> operands,
907436c6c9cSStella Laurenzo     llvm::Optional<py::dict> attributes,
908436c6c9cSStella Laurenzo     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
909436c6c9cSStella Laurenzo     DefaultingPyLocation location, py::object maybeIp) {
910436c6c9cSStella Laurenzo   llvm::SmallVector<MlirValue, 4> mlirOperands;
911436c6c9cSStella Laurenzo   llvm::SmallVector<MlirType, 4> mlirResults;
912436c6c9cSStella Laurenzo   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
913436c6c9cSStella Laurenzo   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
914436c6c9cSStella Laurenzo 
915436c6c9cSStella Laurenzo   // General parameter validation.
916436c6c9cSStella Laurenzo   if (regions < 0)
917436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
918436c6c9cSStella Laurenzo 
919436c6c9cSStella Laurenzo   // Unpack/validate operands.
920436c6c9cSStella Laurenzo   if (operands) {
921436c6c9cSStella Laurenzo     mlirOperands.reserve(operands->size());
922436c6c9cSStella Laurenzo     for (PyValue *operand : *operands) {
923436c6c9cSStella Laurenzo       if (!operand)
924436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
925436c6c9cSStella Laurenzo       mlirOperands.push_back(operand->get());
926436c6c9cSStella Laurenzo     }
927436c6c9cSStella Laurenzo   }
928436c6c9cSStella Laurenzo 
929436c6c9cSStella Laurenzo   // Unpack/validate results.
930436c6c9cSStella Laurenzo   if (results) {
931436c6c9cSStella Laurenzo     mlirResults.reserve(results->size());
932436c6c9cSStella Laurenzo     for (PyType *result : *results) {
933436c6c9cSStella Laurenzo       // TODO: Verify result type originate from the same context.
934436c6c9cSStella Laurenzo       if (!result)
935436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "result type cannot be None");
936436c6c9cSStella Laurenzo       mlirResults.push_back(*result);
937436c6c9cSStella Laurenzo     }
938436c6c9cSStella Laurenzo   }
939436c6c9cSStella Laurenzo   // Unpack/validate attributes.
940436c6c9cSStella Laurenzo   if (attributes) {
941436c6c9cSStella Laurenzo     mlirAttributes.reserve(attributes->size());
942436c6c9cSStella Laurenzo     for (auto &it : *attributes) {
943436c6c9cSStella Laurenzo       std::string key;
944436c6c9cSStella Laurenzo       try {
945436c6c9cSStella Laurenzo         key = it.first.cast<std::string>();
946436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
947436c6c9cSStella Laurenzo         std::string msg = "Invalid attribute key (not a string) when "
948436c6c9cSStella Laurenzo                           "attempting to create the operation \"" +
949436c6c9cSStella Laurenzo                           name + "\" (" + err.what() + ")";
950436c6c9cSStella Laurenzo         throw py::cast_error(msg);
951436c6c9cSStella Laurenzo       }
952436c6c9cSStella Laurenzo       try {
953436c6c9cSStella Laurenzo         auto &attribute = it.second.cast<PyAttribute &>();
954436c6c9cSStella Laurenzo         // TODO: Verify attribute originates from the same context.
955436c6c9cSStella Laurenzo         mlirAttributes.emplace_back(std::move(key), attribute);
956436c6c9cSStella Laurenzo       } catch (py::reference_cast_error &) {
957436c6c9cSStella Laurenzo         // This exception seems thrown when the value is "None".
958436c6c9cSStella Laurenzo         std::string msg =
959436c6c9cSStella Laurenzo             "Found an invalid (`None`?) attribute value for the key \"" + key +
960436c6c9cSStella Laurenzo             "\" when attempting to create the operation \"" + name + "\"";
961436c6c9cSStella Laurenzo         throw py::cast_error(msg);
962436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
963436c6c9cSStella Laurenzo         std::string msg = "Invalid attribute value for the key \"" + key +
964436c6c9cSStella Laurenzo                           "\" when attempting to create the operation \"" +
965436c6c9cSStella Laurenzo                           name + "\" (" + err.what() + ")";
966436c6c9cSStella Laurenzo         throw py::cast_error(msg);
967436c6c9cSStella Laurenzo       }
968436c6c9cSStella Laurenzo     }
969436c6c9cSStella Laurenzo   }
970436c6c9cSStella Laurenzo   // Unpack/validate successors.
971436c6c9cSStella Laurenzo   if (successors) {
972436c6c9cSStella Laurenzo     mlirSuccessors.reserve(successors->size());
973436c6c9cSStella Laurenzo     for (auto *successor : *successors) {
974436c6c9cSStella Laurenzo       // TODO: Verify successor originate from the same context.
975436c6c9cSStella Laurenzo       if (!successor)
976436c6c9cSStella Laurenzo         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
977436c6c9cSStella Laurenzo       mlirSuccessors.push_back(successor->get());
978436c6c9cSStella Laurenzo     }
979436c6c9cSStella Laurenzo   }
980436c6c9cSStella Laurenzo 
981436c6c9cSStella Laurenzo   // Apply unpacked/validated to the operation state. Beyond this
982436c6c9cSStella Laurenzo   // point, exceptions cannot be thrown or else the state will leak.
983436c6c9cSStella Laurenzo   MlirOperationState state =
984436c6c9cSStella Laurenzo       mlirOperationStateGet(toMlirStringRef(name), location);
985436c6c9cSStella Laurenzo   if (!mlirOperands.empty())
986436c6c9cSStella Laurenzo     mlirOperationStateAddOperands(&state, mlirOperands.size(),
987436c6c9cSStella Laurenzo                                   mlirOperands.data());
988436c6c9cSStella Laurenzo   if (!mlirResults.empty())
989436c6c9cSStella Laurenzo     mlirOperationStateAddResults(&state, mlirResults.size(),
990436c6c9cSStella Laurenzo                                  mlirResults.data());
991436c6c9cSStella Laurenzo   if (!mlirAttributes.empty()) {
992436c6c9cSStella Laurenzo     // Note that the attribute names directly reference bytes in
993436c6c9cSStella Laurenzo     // mlirAttributes, so that vector must not be changed from here
994436c6c9cSStella Laurenzo     // on.
995436c6c9cSStella Laurenzo     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
996436c6c9cSStella Laurenzo     mlirNamedAttributes.reserve(mlirAttributes.size());
997436c6c9cSStella Laurenzo     for (auto &it : mlirAttributes)
998436c6c9cSStella Laurenzo       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
999436c6c9cSStella Laurenzo           mlirIdentifierGet(mlirAttributeGetContext(it.second),
1000436c6c9cSStella Laurenzo                             toMlirStringRef(it.first)),
1001436c6c9cSStella Laurenzo           it.second));
1002436c6c9cSStella Laurenzo     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1003436c6c9cSStella Laurenzo                                     mlirNamedAttributes.data());
1004436c6c9cSStella Laurenzo   }
1005436c6c9cSStella Laurenzo   if (!mlirSuccessors.empty())
1006436c6c9cSStella Laurenzo     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1007436c6c9cSStella Laurenzo                                     mlirSuccessors.data());
1008436c6c9cSStella Laurenzo   if (regions) {
1009436c6c9cSStella Laurenzo     llvm::SmallVector<MlirRegion, 4> mlirRegions;
1010436c6c9cSStella Laurenzo     mlirRegions.resize(regions);
1011436c6c9cSStella Laurenzo     for (int i = 0; i < regions; ++i)
1012436c6c9cSStella Laurenzo       mlirRegions[i] = mlirRegionCreate();
1013436c6c9cSStella Laurenzo     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1014436c6c9cSStella Laurenzo                                       mlirRegions.data());
1015436c6c9cSStella Laurenzo   }
1016436c6c9cSStella Laurenzo 
1017436c6c9cSStella Laurenzo   // Construct the operation.
1018436c6c9cSStella Laurenzo   MlirOperation operation = mlirOperationCreate(&state);
1019436c6c9cSStella Laurenzo   PyOperationRef created =
1020436c6c9cSStella Laurenzo       PyOperation::createDetached(location->getContext(), operation);
1021436c6c9cSStella Laurenzo 
1022436c6c9cSStella Laurenzo   // InsertPoint active?
1023436c6c9cSStella Laurenzo   if (!maybeIp.is(py::cast(false))) {
1024436c6c9cSStella Laurenzo     PyInsertionPoint *ip;
1025436c6c9cSStella Laurenzo     if (maybeIp.is_none()) {
1026436c6c9cSStella Laurenzo       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1027436c6c9cSStella Laurenzo     } else {
1028436c6c9cSStella Laurenzo       ip = py::cast<PyInsertionPoint *>(maybeIp);
1029436c6c9cSStella Laurenzo     }
1030436c6c9cSStella Laurenzo     if (ip)
1031436c6c9cSStella Laurenzo       ip->insert(*created.get());
1032436c6c9cSStella Laurenzo   }
1033436c6c9cSStella Laurenzo 
1034436c6c9cSStella Laurenzo   return created->createOpView();
1035436c6c9cSStella Laurenzo }
1036436c6c9cSStella Laurenzo 
1037436c6c9cSStella Laurenzo py::object PyOperation::createOpView() {
103849745f87SMike Urbach   checkValid();
1039436c6c9cSStella Laurenzo   MlirIdentifier ident = mlirOperationGetName(get());
1040436c6c9cSStella Laurenzo   MlirStringRef identStr = mlirIdentifierStr(ident);
1041436c6c9cSStella Laurenzo   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1042436c6c9cSStella Laurenzo       StringRef(identStr.data, identStr.length));
1043436c6c9cSStella Laurenzo   if (opViewClass)
1044436c6c9cSStella Laurenzo     return (*opViewClass)(getRef().getObject());
1045436c6c9cSStella Laurenzo   return py::cast(PyOpView(getRef().getObject()));
1046436c6c9cSStella Laurenzo }
1047436c6c9cSStella Laurenzo 
104849745f87SMike Urbach void PyOperation::erase() {
104949745f87SMike Urbach   checkValid();
105049745f87SMike Urbach   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
105149745f87SMike Urbach   // Python reference to a child operation is live. All children should also
105249745f87SMike Urbach   // have their `valid` bit set to false.
105349745f87SMike Urbach   auto &liveOperations = getContext()->liveOperations;
105449745f87SMike Urbach   if (liveOperations.count(operation.ptr))
105549745f87SMike Urbach     liveOperations.erase(operation.ptr);
105649745f87SMike Urbach   mlirOperationDestroy(operation);
105749745f87SMike Urbach   valid = false;
105849745f87SMike Urbach }
105949745f87SMike Urbach 
1060436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1061436c6c9cSStella Laurenzo // PyOpView
1062436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1063436c6c9cSStella Laurenzo 
1064436c6c9cSStella Laurenzo py::object
1065436c6c9cSStella Laurenzo PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1066436c6c9cSStella Laurenzo                        py::list operandList,
1067436c6c9cSStella Laurenzo                        llvm::Optional<py::dict> attributes,
1068436c6c9cSStella Laurenzo                        llvm::Optional<std::vector<PyBlock *>> successors,
1069436c6c9cSStella Laurenzo                        llvm::Optional<int> regions,
1070436c6c9cSStella Laurenzo                        DefaultingPyLocation location, py::object maybeIp) {
1071436c6c9cSStella Laurenzo   PyMlirContextRef context = location->getContext();
1072436c6c9cSStella Laurenzo   // Class level operation construction metadata.
1073436c6c9cSStella Laurenzo   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1074436c6c9cSStella Laurenzo   // Operand and result segment specs are either none, which does no
1075436c6c9cSStella Laurenzo   // variadic unpacking, or a list of ints with segment sizes, where each
1076436c6c9cSStella Laurenzo   // element is either a positive number (typically 1 for a scalar) or -1 to
1077436c6c9cSStella Laurenzo   // indicate that it is derived from the length of the same-indexed operand
1078436c6c9cSStella Laurenzo   // or result (implying that it is a list at that position).
1079436c6c9cSStella Laurenzo   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1080436c6c9cSStella Laurenzo   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1081436c6c9cSStella Laurenzo 
10828d05a288SStella Laurenzo   std::vector<uint32_t> operandSegmentLengths;
10838d05a288SStella Laurenzo   std::vector<uint32_t> resultSegmentLengths;
1084436c6c9cSStella Laurenzo 
1085436c6c9cSStella Laurenzo   // Validate/determine region count.
1086436c6c9cSStella Laurenzo   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1087436c6c9cSStella Laurenzo   int opMinRegionCount = std::get<0>(opRegionSpec);
1088436c6c9cSStella Laurenzo   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1089436c6c9cSStella Laurenzo   if (!regions) {
1090436c6c9cSStella Laurenzo     regions = opMinRegionCount;
1091436c6c9cSStella Laurenzo   }
1092436c6c9cSStella Laurenzo   if (*regions < opMinRegionCount) {
1093436c6c9cSStella Laurenzo     throw py::value_error(
1094436c6c9cSStella Laurenzo         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1095436c6c9cSStella Laurenzo          llvm::Twine(opMinRegionCount) +
1096436c6c9cSStella Laurenzo          " regions but was built with regions=" + llvm::Twine(*regions))
1097436c6c9cSStella Laurenzo             .str());
1098436c6c9cSStella Laurenzo   }
1099436c6c9cSStella Laurenzo   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1100436c6c9cSStella Laurenzo     throw py::value_error(
1101436c6c9cSStella Laurenzo         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1102436c6c9cSStella Laurenzo          llvm::Twine(opMinRegionCount) +
1103436c6c9cSStella Laurenzo          " regions but was built with regions=" + llvm::Twine(*regions))
1104436c6c9cSStella Laurenzo             .str());
1105436c6c9cSStella Laurenzo   }
1106436c6c9cSStella Laurenzo 
1107436c6c9cSStella Laurenzo   // Unpack results.
1108436c6c9cSStella Laurenzo   std::vector<PyType *> resultTypes;
1109436c6c9cSStella Laurenzo   resultTypes.reserve(resultTypeList.size());
1110436c6c9cSStella Laurenzo   if (resultSegmentSpecObj.is_none()) {
1111436c6c9cSStella Laurenzo     // Non-variadic result unpacking.
1112436c6c9cSStella Laurenzo     for (auto it : llvm::enumerate(resultTypeList)) {
1113436c6c9cSStella Laurenzo       try {
1114436c6c9cSStella Laurenzo         resultTypes.push_back(py::cast<PyType *>(it.value()));
1115436c6c9cSStella Laurenzo         if (!resultTypes.back())
1116436c6c9cSStella Laurenzo           throw py::cast_error();
1117436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
1118436c6c9cSStella Laurenzo         throw py::value_error((llvm::Twine("Result ") +
1119436c6c9cSStella Laurenzo                                llvm::Twine(it.index()) + " of operation \"" +
1120436c6c9cSStella Laurenzo                                name + "\" must be a Type (" + err.what() + ")")
1121436c6c9cSStella Laurenzo                                   .str());
1122436c6c9cSStella Laurenzo       }
1123436c6c9cSStella Laurenzo     }
1124436c6c9cSStella Laurenzo   } else {
1125436c6c9cSStella Laurenzo     // Sized result unpacking.
1126436c6c9cSStella Laurenzo     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1127436c6c9cSStella Laurenzo     if (resultSegmentSpec.size() != resultTypeList.size()) {
1128436c6c9cSStella Laurenzo       throw py::value_error((llvm::Twine("Operation \"") + name +
1129436c6c9cSStella Laurenzo                              "\" requires " +
1130436c6c9cSStella Laurenzo                              llvm::Twine(resultSegmentSpec.size()) +
1131436c6c9cSStella Laurenzo                              "result segments but was provided " +
1132436c6c9cSStella Laurenzo                              llvm::Twine(resultTypeList.size()))
1133436c6c9cSStella Laurenzo                                 .str());
1134436c6c9cSStella Laurenzo     }
1135436c6c9cSStella Laurenzo     resultSegmentLengths.reserve(resultTypeList.size());
1136436c6c9cSStella Laurenzo     for (auto it :
1137436c6c9cSStella Laurenzo          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1138436c6c9cSStella Laurenzo       int segmentSpec = std::get<1>(it.value());
1139436c6c9cSStella Laurenzo       if (segmentSpec == 1 || segmentSpec == 0) {
1140436c6c9cSStella Laurenzo         // Unpack unary element.
1141436c6c9cSStella Laurenzo         try {
1142436c6c9cSStella Laurenzo           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1143436c6c9cSStella Laurenzo           if (resultType) {
1144436c6c9cSStella Laurenzo             resultTypes.push_back(resultType);
1145436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(1);
1146436c6c9cSStella Laurenzo           } else if (segmentSpec == 0) {
1147436c6c9cSStella Laurenzo             // Allowed to be optional.
1148436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(0);
1149436c6c9cSStella Laurenzo           } else {
1150436c6c9cSStella Laurenzo             throw py::cast_error("was None and result is not optional");
1151436c6c9cSStella Laurenzo           }
1152436c6c9cSStella Laurenzo         } catch (py::cast_error &err) {
1153436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Result ") +
1154436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1155436c6c9cSStella Laurenzo                                  name + "\" must be a Type (" + err.what() +
1156436c6c9cSStella Laurenzo                                  ")")
1157436c6c9cSStella Laurenzo                                     .str());
1158436c6c9cSStella Laurenzo         }
1159436c6c9cSStella Laurenzo       } else if (segmentSpec == -1) {
1160436c6c9cSStella Laurenzo         // Unpack sequence by appending.
1161436c6c9cSStella Laurenzo         try {
1162436c6c9cSStella Laurenzo           if (std::get<0>(it.value()).is_none()) {
1163436c6c9cSStella Laurenzo             // Treat it as an empty list.
1164436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(0);
1165436c6c9cSStella Laurenzo           } else {
1166436c6c9cSStella Laurenzo             // Unpack the list.
1167436c6c9cSStella Laurenzo             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1168436c6c9cSStella Laurenzo             for (py::object segmentItem : segment) {
1169436c6c9cSStella Laurenzo               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1170436c6c9cSStella Laurenzo               if (!resultTypes.back()) {
1171436c6c9cSStella Laurenzo                 throw py::cast_error("contained a None item");
1172436c6c9cSStella Laurenzo               }
1173436c6c9cSStella Laurenzo             }
1174436c6c9cSStella Laurenzo             resultSegmentLengths.push_back(segment.size());
1175436c6c9cSStella Laurenzo           }
1176436c6c9cSStella Laurenzo         } catch (std::exception &err) {
1177436c6c9cSStella Laurenzo           // NOTE: Sloppy to be using a catch-all here, but there are at least
1178436c6c9cSStella Laurenzo           // three different unrelated exceptions that can be thrown in the
1179436c6c9cSStella Laurenzo           // above "casts". Just keep the scope above small and catch them all.
1180436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Result ") +
1181436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1182436c6c9cSStella Laurenzo                                  name + "\" must be a Sequence of Types (" +
1183436c6c9cSStella Laurenzo                                  err.what() + ")")
1184436c6c9cSStella Laurenzo                                     .str());
1185436c6c9cSStella Laurenzo         }
1186436c6c9cSStella Laurenzo       } else {
1187436c6c9cSStella Laurenzo         throw py::value_error("Unexpected segment spec");
1188436c6c9cSStella Laurenzo       }
1189436c6c9cSStella Laurenzo     }
1190436c6c9cSStella Laurenzo   }
1191436c6c9cSStella Laurenzo 
1192436c6c9cSStella Laurenzo   // Unpack operands.
1193436c6c9cSStella Laurenzo   std::vector<PyValue *> operands;
1194436c6c9cSStella Laurenzo   operands.reserve(operands.size());
1195436c6c9cSStella Laurenzo   if (operandSegmentSpecObj.is_none()) {
1196436c6c9cSStella Laurenzo     // Non-sized operand unpacking.
1197436c6c9cSStella Laurenzo     for (auto it : llvm::enumerate(operandList)) {
1198436c6c9cSStella Laurenzo       try {
1199436c6c9cSStella Laurenzo         operands.push_back(py::cast<PyValue *>(it.value()));
1200436c6c9cSStella Laurenzo         if (!operands.back())
1201436c6c9cSStella Laurenzo           throw py::cast_error();
1202436c6c9cSStella Laurenzo       } catch (py::cast_error &err) {
1203436c6c9cSStella Laurenzo         throw py::value_error((llvm::Twine("Operand ") +
1204436c6c9cSStella Laurenzo                                llvm::Twine(it.index()) + " of operation \"" +
1205436c6c9cSStella Laurenzo                                name + "\" must be a Value (" + err.what() + ")")
1206436c6c9cSStella Laurenzo                                   .str());
1207436c6c9cSStella Laurenzo       }
1208436c6c9cSStella Laurenzo     }
1209436c6c9cSStella Laurenzo   } else {
1210436c6c9cSStella Laurenzo     // Sized operand unpacking.
1211436c6c9cSStella Laurenzo     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1212436c6c9cSStella Laurenzo     if (operandSegmentSpec.size() != operandList.size()) {
1213436c6c9cSStella Laurenzo       throw py::value_error((llvm::Twine("Operation \"") + name +
1214436c6c9cSStella Laurenzo                              "\" requires " +
1215436c6c9cSStella Laurenzo                              llvm::Twine(operandSegmentSpec.size()) +
1216436c6c9cSStella Laurenzo                              "operand segments but was provided " +
1217436c6c9cSStella Laurenzo                              llvm::Twine(operandList.size()))
1218436c6c9cSStella Laurenzo                                 .str());
1219436c6c9cSStella Laurenzo     }
1220436c6c9cSStella Laurenzo     operandSegmentLengths.reserve(operandList.size());
1221436c6c9cSStella Laurenzo     for (auto it :
1222436c6c9cSStella Laurenzo          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1223436c6c9cSStella Laurenzo       int segmentSpec = std::get<1>(it.value());
1224436c6c9cSStella Laurenzo       if (segmentSpec == 1 || segmentSpec == 0) {
1225436c6c9cSStella Laurenzo         // Unpack unary element.
1226436c6c9cSStella Laurenzo         try {
1227436c6c9cSStella Laurenzo           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1228436c6c9cSStella Laurenzo           if (operandValue) {
1229436c6c9cSStella Laurenzo             operands.push_back(operandValue);
1230436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(1);
1231436c6c9cSStella Laurenzo           } else if (segmentSpec == 0) {
1232436c6c9cSStella Laurenzo             // Allowed to be optional.
1233436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(0);
1234436c6c9cSStella Laurenzo           } else {
1235436c6c9cSStella Laurenzo             throw py::cast_error("was None and operand is not optional");
1236436c6c9cSStella Laurenzo           }
1237436c6c9cSStella Laurenzo         } catch (py::cast_error &err) {
1238436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Operand ") +
1239436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1240436c6c9cSStella Laurenzo                                  name + "\" must be a Value (" + err.what() +
1241436c6c9cSStella Laurenzo                                  ")")
1242436c6c9cSStella Laurenzo                                     .str());
1243436c6c9cSStella Laurenzo         }
1244436c6c9cSStella Laurenzo       } else if (segmentSpec == -1) {
1245436c6c9cSStella Laurenzo         // Unpack sequence by appending.
1246436c6c9cSStella Laurenzo         try {
1247436c6c9cSStella Laurenzo           if (std::get<0>(it.value()).is_none()) {
1248436c6c9cSStella Laurenzo             // Treat it as an empty list.
1249436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(0);
1250436c6c9cSStella Laurenzo           } else {
1251436c6c9cSStella Laurenzo             // Unpack the list.
1252436c6c9cSStella Laurenzo             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1253436c6c9cSStella Laurenzo             for (py::object segmentItem : segment) {
1254436c6c9cSStella Laurenzo               operands.push_back(py::cast<PyValue *>(segmentItem));
1255436c6c9cSStella Laurenzo               if (!operands.back()) {
1256436c6c9cSStella Laurenzo                 throw py::cast_error("contained a None item");
1257436c6c9cSStella Laurenzo               }
1258436c6c9cSStella Laurenzo             }
1259436c6c9cSStella Laurenzo             operandSegmentLengths.push_back(segment.size());
1260436c6c9cSStella Laurenzo           }
1261436c6c9cSStella Laurenzo         } catch (std::exception &err) {
1262436c6c9cSStella Laurenzo           // NOTE: Sloppy to be using a catch-all here, but there are at least
1263436c6c9cSStella Laurenzo           // three different unrelated exceptions that can be thrown in the
1264436c6c9cSStella Laurenzo           // above "casts". Just keep the scope above small and catch them all.
1265436c6c9cSStella Laurenzo           throw py::value_error((llvm::Twine("Operand ") +
1266436c6c9cSStella Laurenzo                                  llvm::Twine(it.index()) + " of operation \"" +
1267436c6c9cSStella Laurenzo                                  name + "\" must be a Sequence of Values (" +
1268436c6c9cSStella Laurenzo                                  err.what() + ")")
1269436c6c9cSStella Laurenzo                                     .str());
1270436c6c9cSStella Laurenzo         }
1271436c6c9cSStella Laurenzo       } else {
1272436c6c9cSStella Laurenzo         throw py::value_error("Unexpected segment spec");
1273436c6c9cSStella Laurenzo       }
1274436c6c9cSStella Laurenzo     }
1275436c6c9cSStella Laurenzo   }
1276436c6c9cSStella Laurenzo 
1277436c6c9cSStella Laurenzo   // Merge operand/result segment lengths into attributes if needed.
1278436c6c9cSStella Laurenzo   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1279436c6c9cSStella Laurenzo     // Dup.
1280436c6c9cSStella Laurenzo     if (attributes) {
1281436c6c9cSStella Laurenzo       attributes = py::dict(*attributes);
1282436c6c9cSStella Laurenzo     } else {
1283436c6c9cSStella Laurenzo       attributes = py::dict();
1284436c6c9cSStella Laurenzo     }
1285436c6c9cSStella Laurenzo     if (attributes->contains("result_segment_sizes") ||
1286436c6c9cSStella Laurenzo         attributes->contains("operand_segment_sizes")) {
1287436c6c9cSStella Laurenzo       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1288436c6c9cSStella Laurenzo                             "'operand_segment_sizes' attribute is unsupported. "
1289436c6c9cSStella Laurenzo                             "Use Operation.create for such low-level access.");
1290436c6c9cSStella Laurenzo     }
1291436c6c9cSStella Laurenzo 
1292436c6c9cSStella Laurenzo     // Add result_segment_sizes attribute.
1293436c6c9cSStella Laurenzo     if (!resultSegmentLengths.empty()) {
1294436c6c9cSStella Laurenzo       int64_t size = resultSegmentLengths.size();
12958d05a288SStella Laurenzo       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
12968d05a288SStella Laurenzo           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1297436c6c9cSStella Laurenzo           resultSegmentLengths.size(), resultSegmentLengths.data());
1298436c6c9cSStella Laurenzo       (*attributes)["result_segment_sizes"] =
1299436c6c9cSStella Laurenzo           PyAttribute(context, segmentLengthAttr);
1300436c6c9cSStella Laurenzo     }
1301436c6c9cSStella Laurenzo 
1302436c6c9cSStella Laurenzo     // Add operand_segment_sizes attribute.
1303436c6c9cSStella Laurenzo     if (!operandSegmentLengths.empty()) {
1304436c6c9cSStella Laurenzo       int64_t size = operandSegmentLengths.size();
13058d05a288SStella Laurenzo       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
13068d05a288SStella Laurenzo           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1307436c6c9cSStella Laurenzo           operandSegmentLengths.size(), operandSegmentLengths.data());
1308436c6c9cSStella Laurenzo       (*attributes)["operand_segment_sizes"] =
1309436c6c9cSStella Laurenzo           PyAttribute(context, segmentLengthAttr);
1310436c6c9cSStella Laurenzo     }
1311436c6c9cSStella Laurenzo   }
1312436c6c9cSStella Laurenzo 
1313436c6c9cSStella Laurenzo   // Delegate to create.
1314436c6c9cSStella Laurenzo   return PyOperation::create(std::move(name),
1315436c6c9cSStella Laurenzo                              /*results=*/std::move(resultTypes),
1316436c6c9cSStella Laurenzo                              /*operands=*/std::move(operands),
1317436c6c9cSStella Laurenzo                              /*attributes=*/std::move(attributes),
1318436c6c9cSStella Laurenzo                              /*successors=*/std::move(successors),
1319436c6c9cSStella Laurenzo                              /*regions=*/*regions, location, maybeIp);
1320436c6c9cSStella Laurenzo }
1321436c6c9cSStella Laurenzo 
1322436c6c9cSStella Laurenzo PyOpView::PyOpView(py::object operationObject)
1323436c6c9cSStella Laurenzo     // Casting through the PyOperationBase base-class and then back to the
1324436c6c9cSStella Laurenzo     // Operation lets us accept any PyOperationBase subclass.
1325436c6c9cSStella Laurenzo     : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1326436c6c9cSStella Laurenzo       operationObject(operation.getRef().getObject()) {}
1327436c6c9cSStella Laurenzo 
1328436c6c9cSStella Laurenzo py::object PyOpView::createRawSubclass(py::object userClass) {
1329436c6c9cSStella Laurenzo   // This is... a little gross. The typical pattern is to have a pure python
1330436c6c9cSStella Laurenzo   // class that extends OpView like:
1331436c6c9cSStella Laurenzo   //   class AddFOp(_cext.ir.OpView):
1332436c6c9cSStella Laurenzo   //     def __init__(self, loc, lhs, rhs):
1333436c6c9cSStella Laurenzo   //       operation = loc.context.create_operation(
1334436c6c9cSStella Laurenzo   //           "addf", lhs, rhs, results=[lhs.type])
1335436c6c9cSStella Laurenzo   //       super().__init__(operation)
1336436c6c9cSStella Laurenzo   //
1337436c6c9cSStella Laurenzo   // I.e. The goal of the user facing type is to provide a nice constructor
1338436c6c9cSStella Laurenzo   // that has complete freedom for the op under construction. This is at odds
1339436c6c9cSStella Laurenzo   // with our other desire to sometimes create this object by just passing an
1340436c6c9cSStella Laurenzo   // operation (to initialize the base class). We could do *arg and **kwargs
1341436c6c9cSStella Laurenzo   // munging to try to make it work, but instead, we synthesize a new class
1342436c6c9cSStella Laurenzo   // on the fly which extends this user class (AddFOp in this example) and
1343436c6c9cSStella Laurenzo   // *give it* the base class's __init__ method, thus bypassing the
1344436c6c9cSStella Laurenzo   // intermediate subclass's __init__ method entirely. While slightly,
1345436c6c9cSStella Laurenzo   // underhanded, this is safe/legal because the type hierarchy has not changed
1346436c6c9cSStella Laurenzo   // (we just added a new leaf) and we aren't mucking around with __new__.
1347436c6c9cSStella Laurenzo   // Typically, this new class will be stored on the original as "_Raw" and will
1348436c6c9cSStella Laurenzo   // be used for casts and other things that need a variant of the class that
1349436c6c9cSStella Laurenzo   // is initialized purely from an operation.
1350436c6c9cSStella Laurenzo   py::object parentMetaclass =
1351436c6c9cSStella Laurenzo       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1352436c6c9cSStella Laurenzo   py::dict attributes;
1353436c6c9cSStella Laurenzo   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1354436c6c9cSStella Laurenzo   // now.
1355436c6c9cSStella Laurenzo   //   auto opViewType = py::type::of<PyOpView>();
1356436c6c9cSStella Laurenzo   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1357436c6c9cSStella Laurenzo   attributes["__init__"] = opViewType.attr("__init__");
1358436c6c9cSStella Laurenzo   py::str origName = userClass.attr("__name__");
1359436c6c9cSStella Laurenzo   py::str newName = py::str("_") + origName;
1360436c6c9cSStella Laurenzo   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1361436c6c9cSStella Laurenzo }
1362436c6c9cSStella Laurenzo 
1363436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1364436c6c9cSStella Laurenzo // PyInsertionPoint.
1365436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1366436c6c9cSStella Laurenzo 
1367436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1368436c6c9cSStella Laurenzo 
1369436c6c9cSStella Laurenzo PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1370436c6c9cSStella Laurenzo     : refOperation(beforeOperationBase.getOperation().getRef()),
1371436c6c9cSStella Laurenzo       block((*refOperation)->getBlock()) {}
1372436c6c9cSStella Laurenzo 
1373436c6c9cSStella Laurenzo void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1374436c6c9cSStella Laurenzo   PyOperation &operation = operationBase.getOperation();
1375436c6c9cSStella Laurenzo   if (operation.isAttached())
1376436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError,
1377436c6c9cSStella Laurenzo                      "Attempt to insert operation that is already attached");
1378436c6c9cSStella Laurenzo   block.getParentOperation()->checkValid();
1379436c6c9cSStella Laurenzo   MlirOperation beforeOp = {nullptr};
1380436c6c9cSStella Laurenzo   if (refOperation) {
1381436c6c9cSStella Laurenzo     // Insert before operation.
1382436c6c9cSStella Laurenzo     (*refOperation)->checkValid();
1383436c6c9cSStella Laurenzo     beforeOp = (*refOperation)->get();
1384436c6c9cSStella Laurenzo   } else {
1385436c6c9cSStella Laurenzo     // Insert at end (before null) is only valid if the block does not
1386436c6c9cSStella Laurenzo     // already end in a known terminator (violating this will cause assertion
1387436c6c9cSStella Laurenzo     // failures later).
1388436c6c9cSStella Laurenzo     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1389436c6c9cSStella Laurenzo       throw py::index_error("Cannot insert operation at the end of a block "
1390436c6c9cSStella Laurenzo                             "that already has a terminator. Did you mean to "
1391436c6c9cSStella Laurenzo                             "use 'InsertionPoint.at_block_terminator(block)' "
1392436c6c9cSStella Laurenzo                             "versus 'InsertionPoint(block)'?");
1393436c6c9cSStella Laurenzo     }
1394436c6c9cSStella Laurenzo   }
1395436c6c9cSStella Laurenzo   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1396436c6c9cSStella Laurenzo   operation.setAttached();
1397436c6c9cSStella Laurenzo }
1398436c6c9cSStella Laurenzo 
1399436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1400436c6c9cSStella Laurenzo   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1401436c6c9cSStella Laurenzo   if (mlirOperationIsNull(firstOp)) {
1402436c6c9cSStella Laurenzo     // Just insert at end.
1403436c6c9cSStella Laurenzo     return PyInsertionPoint(block);
1404436c6c9cSStella Laurenzo   }
1405436c6c9cSStella Laurenzo 
1406436c6c9cSStella Laurenzo   // Insert before first op.
1407436c6c9cSStella Laurenzo   PyOperationRef firstOpRef = PyOperation::forOperation(
1408436c6c9cSStella Laurenzo       block.getParentOperation()->getContext(), firstOp);
1409436c6c9cSStella Laurenzo   return PyInsertionPoint{block, std::move(firstOpRef)};
1410436c6c9cSStella Laurenzo }
1411436c6c9cSStella Laurenzo 
1412436c6c9cSStella Laurenzo PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1413436c6c9cSStella Laurenzo   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1414436c6c9cSStella Laurenzo   if (mlirOperationIsNull(terminator))
1415436c6c9cSStella Laurenzo     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1416436c6c9cSStella Laurenzo   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1417436c6c9cSStella Laurenzo       block.getParentOperation()->getContext(), terminator);
1418436c6c9cSStella Laurenzo   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1419436c6c9cSStella Laurenzo }
1420436c6c9cSStella Laurenzo 
1421436c6c9cSStella Laurenzo py::object PyInsertionPoint::contextEnter() {
1422436c6c9cSStella Laurenzo   return PyThreadContextEntry::pushInsertionPoint(*this);
1423436c6c9cSStella Laurenzo }
1424436c6c9cSStella Laurenzo 
1425436c6c9cSStella Laurenzo void PyInsertionPoint::contextExit(pybind11::object excType,
1426436c6c9cSStella Laurenzo                                    pybind11::object excVal,
1427436c6c9cSStella Laurenzo                                    pybind11::object excTb) {
1428436c6c9cSStella Laurenzo   PyThreadContextEntry::popInsertionPoint(*this);
1429436c6c9cSStella Laurenzo }
1430436c6c9cSStella Laurenzo 
1431436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1432436c6c9cSStella Laurenzo // PyAttribute.
1433436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1434436c6c9cSStella Laurenzo 
1435436c6c9cSStella Laurenzo bool PyAttribute::operator==(const PyAttribute &other) {
1436436c6c9cSStella Laurenzo   return mlirAttributeEqual(attr, other.attr);
1437436c6c9cSStella Laurenzo }
1438436c6c9cSStella Laurenzo 
1439436c6c9cSStella Laurenzo py::object PyAttribute::getCapsule() {
1440436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1441436c6c9cSStella Laurenzo }
1442436c6c9cSStella Laurenzo 
1443436c6c9cSStella Laurenzo PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1444436c6c9cSStella Laurenzo   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1445436c6c9cSStella Laurenzo   if (mlirAttributeIsNull(rawAttr))
1446436c6c9cSStella Laurenzo     throw py::error_already_set();
1447436c6c9cSStella Laurenzo   return PyAttribute(
1448436c6c9cSStella Laurenzo       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1449436c6c9cSStella Laurenzo }
1450436c6c9cSStella Laurenzo 
1451436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1452436c6c9cSStella Laurenzo // PyNamedAttribute.
1453436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1454436c6c9cSStella Laurenzo 
1455436c6c9cSStella Laurenzo PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1456436c6c9cSStella Laurenzo     : ownedName(new std::string(std::move(ownedName))) {
1457436c6c9cSStella Laurenzo   namedAttr = mlirNamedAttributeGet(
1458436c6c9cSStella Laurenzo       mlirIdentifierGet(mlirAttributeGetContext(attr),
1459436c6c9cSStella Laurenzo                         toMlirStringRef(*this->ownedName)),
1460436c6c9cSStella Laurenzo       attr);
1461436c6c9cSStella Laurenzo }
1462436c6c9cSStella Laurenzo 
1463436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1464436c6c9cSStella Laurenzo // PyType.
1465436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1466436c6c9cSStella Laurenzo 
1467436c6c9cSStella Laurenzo bool PyType::operator==(const PyType &other) {
1468436c6c9cSStella Laurenzo   return mlirTypeEqual(type, other.type);
1469436c6c9cSStella Laurenzo }
1470436c6c9cSStella Laurenzo 
1471436c6c9cSStella Laurenzo py::object PyType::getCapsule() {
1472436c6c9cSStella Laurenzo   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1473436c6c9cSStella Laurenzo }
1474436c6c9cSStella Laurenzo 
1475436c6c9cSStella Laurenzo PyType PyType::createFromCapsule(py::object capsule) {
1476436c6c9cSStella Laurenzo   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1477436c6c9cSStella Laurenzo   if (mlirTypeIsNull(rawType))
1478436c6c9cSStella Laurenzo     throw py::error_already_set();
1479436c6c9cSStella Laurenzo   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1480436c6c9cSStella Laurenzo                 rawType);
1481436c6c9cSStella Laurenzo }
1482436c6c9cSStella Laurenzo 
1483436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1484436c6c9cSStella Laurenzo // PyValue and subclasses.
1485436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1486436c6c9cSStella Laurenzo 
14873f3d1c90SMike Urbach pybind11::object PyValue::getCapsule() {
14883f3d1c90SMike Urbach   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
14893f3d1c90SMike Urbach }
14903f3d1c90SMike Urbach 
14913f3d1c90SMike Urbach PyValue PyValue::createFromCapsule(pybind11::object capsule) {
14923f3d1c90SMike Urbach   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
14933f3d1c90SMike Urbach   if (mlirValueIsNull(value))
14943f3d1c90SMike Urbach     throw py::error_already_set();
14953f3d1c90SMike Urbach   MlirOperation owner;
14963f3d1c90SMike Urbach   if (mlirValueIsAOpResult(value))
14973f3d1c90SMike Urbach     owner = mlirOpResultGetOwner(value);
14983f3d1c90SMike Urbach   if (mlirValueIsABlockArgument(value))
14993f3d1c90SMike Urbach     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
15003f3d1c90SMike Urbach   if (mlirOperationIsNull(owner))
15013f3d1c90SMike Urbach     throw py::error_already_set();
15023f3d1c90SMike Urbach   MlirContext ctx = mlirOperationGetContext(owner);
15033f3d1c90SMike Urbach   PyOperationRef ownerRef =
15043f3d1c90SMike Urbach       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
15053f3d1c90SMike Urbach   return PyValue(ownerRef, value);
15063f3d1c90SMike Urbach }
15073f3d1c90SMike Urbach 
1508436c6c9cSStella Laurenzo namespace {
1509436c6c9cSStella Laurenzo /// CRTP base class for Python MLIR values that subclass Value and should be
1510436c6c9cSStella Laurenzo /// castable from it. The value hierarchy is one level deep and is not supposed
1511436c6c9cSStella Laurenzo /// to accommodate other levels unless core MLIR changes.
1512436c6c9cSStella Laurenzo template <typename DerivedTy>
1513436c6c9cSStella Laurenzo class PyConcreteValue : public PyValue {
1514436c6c9cSStella Laurenzo public:
1515436c6c9cSStella Laurenzo   // Derived classes must define statics for:
1516436c6c9cSStella Laurenzo   //   IsAFunctionTy isaFunction
1517436c6c9cSStella Laurenzo   //   const char *pyClassName
1518436c6c9cSStella Laurenzo   // and redefine bindDerived.
1519436c6c9cSStella Laurenzo   using ClassTy = py::class_<DerivedTy, PyValue>;
1520436c6c9cSStella Laurenzo   using IsAFunctionTy = bool (*)(MlirValue);
1521436c6c9cSStella Laurenzo 
1522436c6c9cSStella Laurenzo   PyConcreteValue() = default;
1523436c6c9cSStella Laurenzo   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1524436c6c9cSStella Laurenzo       : PyValue(operationRef, value) {}
1525436c6c9cSStella Laurenzo   PyConcreteValue(PyValue &orig)
1526436c6c9cSStella Laurenzo       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1527436c6c9cSStella Laurenzo 
1528436c6c9cSStella Laurenzo   /// Attempts to cast the original value to the derived type and throws on
1529436c6c9cSStella Laurenzo   /// type mismatches.
1530436c6c9cSStella Laurenzo   static MlirValue castFrom(PyValue &orig) {
1531436c6c9cSStella Laurenzo     if (!DerivedTy::isaFunction(orig.get())) {
1532436c6c9cSStella Laurenzo       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1533436c6c9cSStella Laurenzo       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1534436c6c9cSStella Laurenzo                                              DerivedTy::pyClassName +
1535436c6c9cSStella Laurenzo                                              " (from " + origRepr + ")");
1536436c6c9cSStella Laurenzo     }
1537436c6c9cSStella Laurenzo     return orig.get();
1538436c6c9cSStella Laurenzo   }
1539436c6c9cSStella Laurenzo 
1540436c6c9cSStella Laurenzo   /// Binds the Python module objects to functions of this class.
1541436c6c9cSStella Laurenzo   static void bind(py::module &m) {
1542f05ff4f7SStella Laurenzo     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1543436c6c9cSStella Laurenzo     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1544436c6c9cSStella Laurenzo     DerivedTy::bindDerived(cls);
1545436c6c9cSStella Laurenzo   }
1546436c6c9cSStella Laurenzo 
1547436c6c9cSStella Laurenzo   /// Implemented by derived classes to add methods to the Python subclass.
1548436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &m) {}
1549436c6c9cSStella Laurenzo };
1550436c6c9cSStella Laurenzo 
1551436c6c9cSStella Laurenzo /// Python wrapper for MlirBlockArgument.
1552436c6c9cSStella Laurenzo class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1553436c6c9cSStella Laurenzo public:
1554436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1555436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "BlockArgument";
1556436c6c9cSStella Laurenzo   using PyConcreteValue::PyConcreteValue;
1557436c6c9cSStella Laurenzo 
1558436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1559436c6c9cSStella Laurenzo     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1560436c6c9cSStella Laurenzo       return PyBlock(self.getParentOperation(),
1561436c6c9cSStella Laurenzo                      mlirBlockArgumentGetOwner(self.get()));
1562436c6c9cSStella Laurenzo     });
1563436c6c9cSStella Laurenzo     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1564436c6c9cSStella Laurenzo       return mlirBlockArgumentGetArgNumber(self.get());
1565436c6c9cSStella Laurenzo     });
1566436c6c9cSStella Laurenzo     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1567436c6c9cSStella Laurenzo       return mlirBlockArgumentSetType(self.get(), type);
1568436c6c9cSStella Laurenzo     });
1569436c6c9cSStella Laurenzo   }
1570436c6c9cSStella Laurenzo };
1571436c6c9cSStella Laurenzo 
1572436c6c9cSStella Laurenzo /// Python wrapper for MlirOpResult.
1573436c6c9cSStella Laurenzo class PyOpResult : public PyConcreteValue<PyOpResult> {
1574436c6c9cSStella Laurenzo public:
1575436c6c9cSStella Laurenzo   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1576436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpResult";
1577436c6c9cSStella Laurenzo   using PyConcreteValue::PyConcreteValue;
1578436c6c9cSStella Laurenzo 
1579436c6c9cSStella Laurenzo   static void bindDerived(ClassTy &c) {
1580436c6c9cSStella Laurenzo     c.def_property_readonly("owner", [](PyOpResult &self) {
1581436c6c9cSStella Laurenzo       assert(
1582436c6c9cSStella Laurenzo           mlirOperationEqual(self.getParentOperation()->get(),
1583436c6c9cSStella Laurenzo                              mlirOpResultGetOwner(self.get())) &&
1584436c6c9cSStella Laurenzo           "expected the owner of the value in Python to match that in the IR");
15856ff74f96SMike Urbach       return self.getParentOperation().getObject();
1586436c6c9cSStella Laurenzo     });
1587436c6c9cSStella Laurenzo     c.def_property_readonly("result_number", [](PyOpResult &self) {
1588436c6c9cSStella Laurenzo       return mlirOpResultGetResultNumber(self.get());
1589436c6c9cSStella Laurenzo     });
1590436c6c9cSStella Laurenzo   }
1591436c6c9cSStella Laurenzo };
1592436c6c9cSStella Laurenzo 
1593436c6c9cSStella Laurenzo /// A list of block arguments. Internally, these are stored as consecutive
1594436c6c9cSStella Laurenzo /// elements, random access is cheap. The argument list is associated with the
1595436c6c9cSStella Laurenzo /// operation that contains the block (detached blocks are not allowed in
1596436c6c9cSStella Laurenzo /// Python bindings) and extends its lifetime.
1597436c6c9cSStella Laurenzo class PyBlockArgumentList {
1598436c6c9cSStella Laurenzo public:
1599436c6c9cSStella Laurenzo   PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
1600436c6c9cSStella Laurenzo       : operation(std::move(operation)), block(block) {}
1601436c6c9cSStella Laurenzo 
1602436c6c9cSStella Laurenzo   /// Returns the length of the block argument list.
1603436c6c9cSStella Laurenzo   intptr_t dunderLen() {
1604436c6c9cSStella Laurenzo     operation->checkValid();
1605436c6c9cSStella Laurenzo     return mlirBlockGetNumArguments(block);
1606436c6c9cSStella Laurenzo   }
1607436c6c9cSStella Laurenzo 
1608436c6c9cSStella Laurenzo   /// Returns `index`-th element of the block argument list.
1609436c6c9cSStella Laurenzo   PyBlockArgument dunderGetItem(intptr_t index) {
1610436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
1611436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
1612436c6c9cSStella Laurenzo                        "attempt to access out of bounds region");
1613436c6c9cSStella Laurenzo     }
1614436c6c9cSStella Laurenzo     PyValue value(operation, mlirBlockGetArgument(block, index));
1615436c6c9cSStella Laurenzo     return PyBlockArgument(value);
1616436c6c9cSStella Laurenzo   }
1617436c6c9cSStella Laurenzo 
1618436c6c9cSStella Laurenzo   /// Defines a Python class in the bindings.
1619436c6c9cSStella Laurenzo   static void bind(py::module &m) {
1620f05ff4f7SStella Laurenzo     py::class_<PyBlockArgumentList>(m, "BlockArgumentList", py::module_local())
1621436c6c9cSStella Laurenzo         .def("__len__", &PyBlockArgumentList::dunderLen)
1622436c6c9cSStella Laurenzo         .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
1623436c6c9cSStella Laurenzo   }
1624436c6c9cSStella Laurenzo 
1625436c6c9cSStella Laurenzo private:
1626436c6c9cSStella Laurenzo   PyOperationRef operation;
1627436c6c9cSStella Laurenzo   MlirBlock block;
1628436c6c9cSStella Laurenzo };
1629436c6c9cSStella Laurenzo 
1630436c6c9cSStella Laurenzo /// A list of operation operands. Internally, these are stored as consecutive
1631436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the
1632436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this
1633436c6c9cSStella Laurenzo /// operation.
1634436c6c9cSStella Laurenzo class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1635436c6c9cSStella Laurenzo public:
1636436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpOperandList";
1637436c6c9cSStella Laurenzo 
1638436c6c9cSStella Laurenzo   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1639436c6c9cSStella Laurenzo                   intptr_t length = -1, intptr_t step = 1)
1640436c6c9cSStella Laurenzo       : Sliceable(startIndex,
1641436c6c9cSStella Laurenzo                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1642436c6c9cSStella Laurenzo                                : length,
1643436c6c9cSStella Laurenzo                   step),
1644436c6c9cSStella Laurenzo         operation(operation) {}
1645436c6c9cSStella Laurenzo 
1646436c6c9cSStella Laurenzo   intptr_t getNumElements() {
1647436c6c9cSStella Laurenzo     operation->checkValid();
1648436c6c9cSStella Laurenzo     return mlirOperationGetNumOperands(operation->get());
1649436c6c9cSStella Laurenzo   }
1650436c6c9cSStella Laurenzo 
1651436c6c9cSStella Laurenzo   PyValue getElement(intptr_t pos) {
16525664c5e2SJohn Demme     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
16535664c5e2SJohn Demme     MlirOperation owner;
16545664c5e2SJohn Demme     if (mlirValueIsAOpResult(operand))
16555664c5e2SJohn Demme       owner = mlirOpResultGetOwner(operand);
16565664c5e2SJohn Demme     else if (mlirValueIsABlockArgument(operand))
16575664c5e2SJohn Demme       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
16585664c5e2SJohn Demme     else
16595664c5e2SJohn Demme       assert(false && "Value must be an block arg or op result.");
16605664c5e2SJohn Demme     PyOperationRef pyOwner =
16615664c5e2SJohn Demme         PyOperation::forOperation(operation->getContext(), owner);
16625664c5e2SJohn Demme     return PyValue(pyOwner, operand);
1663436c6c9cSStella Laurenzo   }
1664436c6c9cSStella Laurenzo 
1665436c6c9cSStella Laurenzo   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1666436c6c9cSStella Laurenzo     return PyOpOperandList(operation, startIndex, length, step);
1667436c6c9cSStella Laurenzo   }
1668436c6c9cSStella Laurenzo 
166963d16d06SMike Urbach   void dunderSetItem(intptr_t index, PyValue value) {
167063d16d06SMike Urbach     index = wrapIndex(index);
167163d16d06SMike Urbach     mlirOperationSetOperand(operation->get(), index, value.get());
167263d16d06SMike Urbach   }
167363d16d06SMike Urbach 
167463d16d06SMike Urbach   static void bindDerived(ClassTy &c) {
167563d16d06SMike Urbach     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
167663d16d06SMike Urbach   }
167763d16d06SMike Urbach 
1678436c6c9cSStella Laurenzo private:
1679436c6c9cSStella Laurenzo   PyOperationRef operation;
1680436c6c9cSStella Laurenzo };
1681436c6c9cSStella Laurenzo 
1682436c6c9cSStella Laurenzo /// A list of operation results. Internally, these are stored as consecutive
1683436c6c9cSStella Laurenzo /// elements, random access is cheap. The result list is associated with the
1684436c6c9cSStella Laurenzo /// operation whose results these are, and extends the lifetime of this
1685436c6c9cSStella Laurenzo /// operation.
1686436c6c9cSStella Laurenzo class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1687436c6c9cSStella Laurenzo public:
1688436c6c9cSStella Laurenzo   static constexpr const char *pyClassName = "OpResultList";
1689436c6c9cSStella Laurenzo 
1690436c6c9cSStella Laurenzo   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1691436c6c9cSStella Laurenzo                  intptr_t length = -1, intptr_t step = 1)
1692436c6c9cSStella Laurenzo       : Sliceable(startIndex,
1693436c6c9cSStella Laurenzo                   length == -1 ? mlirOperationGetNumResults(operation->get())
1694436c6c9cSStella Laurenzo                                : length,
1695436c6c9cSStella Laurenzo                   step),
1696436c6c9cSStella Laurenzo         operation(operation) {}
1697436c6c9cSStella Laurenzo 
1698436c6c9cSStella Laurenzo   intptr_t getNumElements() {
1699436c6c9cSStella Laurenzo     operation->checkValid();
1700436c6c9cSStella Laurenzo     return mlirOperationGetNumResults(operation->get());
1701436c6c9cSStella Laurenzo   }
1702436c6c9cSStella Laurenzo 
1703436c6c9cSStella Laurenzo   PyOpResult getElement(intptr_t index) {
1704436c6c9cSStella Laurenzo     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1705436c6c9cSStella Laurenzo     return PyOpResult(value);
1706436c6c9cSStella Laurenzo   }
1707436c6c9cSStella Laurenzo 
1708436c6c9cSStella Laurenzo   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1709436c6c9cSStella Laurenzo     return PyOpResultList(operation, startIndex, length, step);
1710436c6c9cSStella Laurenzo   }
1711436c6c9cSStella Laurenzo 
1712436c6c9cSStella Laurenzo private:
1713436c6c9cSStella Laurenzo   PyOperationRef operation;
1714436c6c9cSStella Laurenzo };
1715436c6c9cSStella Laurenzo 
1716436c6c9cSStella Laurenzo /// A list of operation attributes. Can be indexed by name, producing
1717436c6c9cSStella Laurenzo /// attributes, or by index, producing named attributes.
1718436c6c9cSStella Laurenzo class PyOpAttributeMap {
1719436c6c9cSStella Laurenzo public:
1720436c6c9cSStella Laurenzo   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1721436c6c9cSStella Laurenzo 
1722436c6c9cSStella Laurenzo   PyAttribute dunderGetItemNamed(const std::string &name) {
1723436c6c9cSStella Laurenzo     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1724436c6c9cSStella Laurenzo                                                          toMlirStringRef(name));
1725436c6c9cSStella Laurenzo     if (mlirAttributeIsNull(attr)) {
1726436c6c9cSStella Laurenzo       throw SetPyError(PyExc_KeyError,
1727436c6c9cSStella Laurenzo                        "attempt to access a non-existent attribute");
1728436c6c9cSStella Laurenzo     }
1729436c6c9cSStella Laurenzo     return PyAttribute(operation->getContext(), attr);
1730436c6c9cSStella Laurenzo   }
1731436c6c9cSStella Laurenzo 
1732436c6c9cSStella Laurenzo   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1733436c6c9cSStella Laurenzo     if (index < 0 || index >= dunderLen()) {
1734436c6c9cSStella Laurenzo       throw SetPyError(PyExc_IndexError,
1735436c6c9cSStella Laurenzo                        "attempt to access out of bounds attribute");
1736436c6c9cSStella Laurenzo     }
1737436c6c9cSStella Laurenzo     MlirNamedAttribute namedAttr =
1738436c6c9cSStella Laurenzo         mlirOperationGetAttribute(operation->get(), index);
1739436c6c9cSStella Laurenzo     return PyNamedAttribute(
1740436c6c9cSStella Laurenzo         namedAttr.attribute,
1741436c6c9cSStella Laurenzo         std::string(mlirIdentifierStr(namedAttr.name).data));
1742436c6c9cSStella Laurenzo   }
1743436c6c9cSStella Laurenzo 
1744436c6c9cSStella Laurenzo   void dunderSetItem(const std::string &name, PyAttribute attr) {
1745436c6c9cSStella Laurenzo     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1746436c6c9cSStella Laurenzo                                     attr);
1747436c6c9cSStella Laurenzo   }
1748436c6c9cSStella Laurenzo 
1749436c6c9cSStella Laurenzo   void dunderDelItem(const std::string &name) {
1750436c6c9cSStella Laurenzo     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1751436c6c9cSStella Laurenzo                                                      toMlirStringRef(name));
1752436c6c9cSStella Laurenzo     if (!removed)
1753436c6c9cSStella Laurenzo       throw SetPyError(PyExc_KeyError,
1754436c6c9cSStella Laurenzo                        "attempt to delete a non-existent attribute");
1755436c6c9cSStella Laurenzo   }
1756436c6c9cSStella Laurenzo 
1757436c6c9cSStella Laurenzo   intptr_t dunderLen() {
1758436c6c9cSStella Laurenzo     return mlirOperationGetNumAttributes(operation->get());
1759436c6c9cSStella Laurenzo   }
1760436c6c9cSStella Laurenzo 
1761436c6c9cSStella Laurenzo   bool dunderContains(const std::string &name) {
1762436c6c9cSStella Laurenzo     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1763436c6c9cSStella Laurenzo         operation->get(), toMlirStringRef(name)));
1764436c6c9cSStella Laurenzo   }
1765436c6c9cSStella Laurenzo 
1766436c6c9cSStella Laurenzo   static void bind(py::module &m) {
1767f05ff4f7SStella Laurenzo     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
1768436c6c9cSStella Laurenzo         .def("__contains__", &PyOpAttributeMap::dunderContains)
1769436c6c9cSStella Laurenzo         .def("__len__", &PyOpAttributeMap::dunderLen)
1770436c6c9cSStella Laurenzo         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1771436c6c9cSStella Laurenzo         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1772436c6c9cSStella Laurenzo         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1773436c6c9cSStella Laurenzo         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1774436c6c9cSStella Laurenzo   }
1775436c6c9cSStella Laurenzo 
1776436c6c9cSStella Laurenzo private:
1777436c6c9cSStella Laurenzo   PyOperationRef operation;
1778436c6c9cSStella Laurenzo };
1779436c6c9cSStella Laurenzo 
1780436c6c9cSStella Laurenzo } // end namespace
1781436c6c9cSStella Laurenzo 
1782436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1783436c6c9cSStella Laurenzo // Populates the core exports of the 'ir' submodule.
1784436c6c9cSStella Laurenzo //------------------------------------------------------------------------------
1785436c6c9cSStella Laurenzo 
1786436c6c9cSStella Laurenzo void mlir::python::populateIRCore(py::module &m) {
1787436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
17884acd8457SAlex Zinenko   // Mapping of MlirContext.
1789436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1790f05ff4f7SStella Laurenzo   py::class_<PyMlirContext>(m, "Context", py::module_local())
1791436c6c9cSStella Laurenzo       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1792436c6c9cSStella Laurenzo       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1793436c6c9cSStella Laurenzo       .def("_get_context_again",
1794436c6c9cSStella Laurenzo            [](PyMlirContext &self) {
1795436c6c9cSStella Laurenzo              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1796436c6c9cSStella Laurenzo              return ref.releaseObject();
1797436c6c9cSStella Laurenzo            })
1798436c6c9cSStella Laurenzo       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1799436c6c9cSStella Laurenzo       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1800436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1801436c6c9cSStella Laurenzo                              &PyMlirContext::getCapsule)
1802436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1803436c6c9cSStella Laurenzo       .def("__enter__", &PyMlirContext::contextEnter)
1804436c6c9cSStella Laurenzo       .def("__exit__", &PyMlirContext::contextExit)
1805436c6c9cSStella Laurenzo       .def_property_readonly_static(
1806436c6c9cSStella Laurenzo           "current",
1807436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
1808436c6c9cSStella Laurenzo             auto *context = PyThreadContextEntry::getDefaultContext();
1809436c6c9cSStella Laurenzo             if (!context)
1810436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current Context");
1811436c6c9cSStella Laurenzo             return context;
1812436c6c9cSStella Laurenzo           },
1813436c6c9cSStella Laurenzo           "Gets the Context bound to the current thread or raises ValueError")
1814436c6c9cSStella Laurenzo       .def_property_readonly(
1815436c6c9cSStella Laurenzo           "dialects",
1816436c6c9cSStella Laurenzo           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1817436c6c9cSStella Laurenzo           "Gets a container for accessing dialects by name")
1818436c6c9cSStella Laurenzo       .def_property_readonly(
1819436c6c9cSStella Laurenzo           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1820436c6c9cSStella Laurenzo           "Alias for 'dialect'")
1821436c6c9cSStella Laurenzo       .def(
1822436c6c9cSStella Laurenzo           "get_dialect_descriptor",
1823436c6c9cSStella Laurenzo           [=](PyMlirContext &self, std::string &name) {
1824436c6c9cSStella Laurenzo             MlirDialect dialect = mlirContextGetOrLoadDialect(
1825436c6c9cSStella Laurenzo                 self.get(), {name.data(), name.size()});
1826436c6c9cSStella Laurenzo             if (mlirDialectIsNull(dialect)) {
1827436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
1828436c6c9cSStella Laurenzo                                Twine("Dialect '") + name + "' not found");
1829436c6c9cSStella Laurenzo             }
1830436c6c9cSStella Laurenzo             return PyDialectDescriptor(self.getRef(), dialect);
1831436c6c9cSStella Laurenzo           },
1832436c6c9cSStella Laurenzo           "Gets or loads a dialect by name, returning its descriptor object")
1833436c6c9cSStella Laurenzo       .def_property(
1834436c6c9cSStella Laurenzo           "allow_unregistered_dialects",
1835436c6c9cSStella Laurenzo           [](PyMlirContext &self) -> bool {
1836436c6c9cSStella Laurenzo             return mlirContextGetAllowUnregisteredDialects(self.get());
1837436c6c9cSStella Laurenzo           },
1838436c6c9cSStella Laurenzo           [](PyMlirContext &self, bool value) {
1839436c6c9cSStella Laurenzo             mlirContextSetAllowUnregisteredDialects(self.get(), value);
18409a9214faSStella Laurenzo           })
1841caa159f0SNicolas Vasilache       .def("enable_multithreading",
1842caa159f0SNicolas Vasilache            [](PyMlirContext &self, bool enable) {
1843caa159f0SNicolas Vasilache              mlirContextEnableMultithreading(self.get(), enable);
1844caa159f0SNicolas Vasilache            })
18459a9214faSStella Laurenzo       .def("is_registered_operation",
18469a9214faSStella Laurenzo            [](PyMlirContext &self, std::string &name) {
18479a9214faSStella Laurenzo              return mlirContextIsRegisteredOperation(
18489a9214faSStella Laurenzo                  self.get(), MlirStringRef{name.data(), name.size()});
1849436c6c9cSStella Laurenzo            });
1850436c6c9cSStella Laurenzo 
1851436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1852436c6c9cSStella Laurenzo   // Mapping of PyDialectDescriptor
1853436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1854f05ff4f7SStella Laurenzo   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
1855436c6c9cSStella Laurenzo       .def_property_readonly("namespace",
1856436c6c9cSStella Laurenzo                              [](PyDialectDescriptor &self) {
1857436c6c9cSStella Laurenzo                                MlirStringRef ns =
1858436c6c9cSStella Laurenzo                                    mlirDialectGetNamespace(self.get());
1859436c6c9cSStella Laurenzo                                return py::str(ns.data, ns.length);
1860436c6c9cSStella Laurenzo                              })
1861436c6c9cSStella Laurenzo       .def("__repr__", [](PyDialectDescriptor &self) {
1862436c6c9cSStella Laurenzo         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1863436c6c9cSStella Laurenzo         std::string repr("<DialectDescriptor ");
1864436c6c9cSStella Laurenzo         repr.append(ns.data, ns.length);
1865436c6c9cSStella Laurenzo         repr.append(">");
1866436c6c9cSStella Laurenzo         return repr;
1867436c6c9cSStella Laurenzo       });
1868436c6c9cSStella Laurenzo 
1869436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1870436c6c9cSStella Laurenzo   // Mapping of PyDialects
1871436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1872f05ff4f7SStella Laurenzo   py::class_<PyDialects>(m, "Dialects", py::module_local())
1873436c6c9cSStella Laurenzo       .def("__getitem__",
1874436c6c9cSStella Laurenzo            [=](PyDialects &self, std::string keyName) {
1875436c6c9cSStella Laurenzo              MlirDialect dialect =
1876436c6c9cSStella Laurenzo                  self.getDialectForKey(keyName, /*attrError=*/false);
1877436c6c9cSStella Laurenzo              py::object descriptor =
1878436c6c9cSStella Laurenzo                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1879436c6c9cSStella Laurenzo              return createCustomDialectWrapper(keyName, std::move(descriptor));
1880436c6c9cSStella Laurenzo            })
1881436c6c9cSStella Laurenzo       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1882436c6c9cSStella Laurenzo         MlirDialect dialect =
1883436c6c9cSStella Laurenzo             self.getDialectForKey(attrName, /*attrError=*/true);
1884436c6c9cSStella Laurenzo         py::object descriptor =
1885436c6c9cSStella Laurenzo             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1886436c6c9cSStella Laurenzo         return createCustomDialectWrapper(attrName, std::move(descriptor));
1887436c6c9cSStella Laurenzo       });
1888436c6c9cSStella Laurenzo 
1889436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1890436c6c9cSStella Laurenzo   // Mapping of PyDialect
1891436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1892f05ff4f7SStella Laurenzo   py::class_<PyDialect>(m, "Dialect", py::module_local())
1893436c6c9cSStella Laurenzo       .def(py::init<py::object>(), "descriptor")
1894436c6c9cSStella Laurenzo       .def_property_readonly(
1895436c6c9cSStella Laurenzo           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1896436c6c9cSStella Laurenzo       .def("__repr__", [](py::object self) {
1897436c6c9cSStella Laurenzo         auto clazz = self.attr("__class__");
1898436c6c9cSStella Laurenzo         return py::str("<Dialect ") +
1899436c6c9cSStella Laurenzo                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1900436c6c9cSStella Laurenzo                clazz.attr("__module__") + py::str(".") +
1901436c6c9cSStella Laurenzo                clazz.attr("__name__") + py::str(")>");
1902436c6c9cSStella Laurenzo       });
1903436c6c9cSStella Laurenzo 
1904436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1905436c6c9cSStella Laurenzo   // Mapping of Location
1906436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1907f05ff4f7SStella Laurenzo   py::class_<PyLocation>(m, "Location", py::module_local())
1908436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
1909436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
1910436c6c9cSStella Laurenzo       .def("__enter__", &PyLocation::contextEnter)
1911436c6c9cSStella Laurenzo       .def("__exit__", &PyLocation::contextExit)
1912436c6c9cSStella Laurenzo       .def("__eq__",
1913436c6c9cSStella Laurenzo            [](PyLocation &self, PyLocation &other) -> bool {
1914436c6c9cSStella Laurenzo              return mlirLocationEqual(self, other);
1915436c6c9cSStella Laurenzo            })
1916436c6c9cSStella Laurenzo       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
1917436c6c9cSStella Laurenzo       .def_property_readonly_static(
1918436c6c9cSStella Laurenzo           "current",
1919436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
1920436c6c9cSStella Laurenzo             auto *loc = PyThreadContextEntry::getDefaultLocation();
1921436c6c9cSStella Laurenzo             if (!loc)
1922436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current Location");
1923436c6c9cSStella Laurenzo             return loc;
1924436c6c9cSStella Laurenzo           },
1925436c6c9cSStella Laurenzo           "Gets the Location bound to the current thread or raises ValueError")
1926436c6c9cSStella Laurenzo       .def_static(
1927436c6c9cSStella Laurenzo           "unknown",
1928436c6c9cSStella Laurenzo           [](DefaultingPyMlirContext context) {
1929436c6c9cSStella Laurenzo             return PyLocation(context->getRef(),
1930436c6c9cSStella Laurenzo                               mlirLocationUnknownGet(context->get()));
1931436c6c9cSStella Laurenzo           },
1932436c6c9cSStella Laurenzo           py::arg("context") = py::none(),
1933436c6c9cSStella Laurenzo           "Gets a Location representing an unknown location")
1934436c6c9cSStella Laurenzo       .def_static(
1935436c6c9cSStella Laurenzo           "file",
1936436c6c9cSStella Laurenzo           [](std::string filename, int line, int col,
1937436c6c9cSStella Laurenzo              DefaultingPyMlirContext context) {
1938436c6c9cSStella Laurenzo             return PyLocation(
1939436c6c9cSStella Laurenzo                 context->getRef(),
1940436c6c9cSStella Laurenzo                 mlirLocationFileLineColGet(
1941436c6c9cSStella Laurenzo                     context->get(), toMlirStringRef(filename), line, col));
1942436c6c9cSStella Laurenzo           },
1943436c6c9cSStella Laurenzo           py::arg("filename"), py::arg("line"), py::arg("col"),
1944436c6c9cSStella Laurenzo           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
1945436c6c9cSStella Laurenzo       .def_property_readonly(
1946436c6c9cSStella Laurenzo           "context",
1947436c6c9cSStella Laurenzo           [](PyLocation &self) { return self.getContext().getObject(); },
1948436c6c9cSStella Laurenzo           "Context that owns the Location")
1949436c6c9cSStella Laurenzo       .def("__repr__", [](PyLocation &self) {
1950436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
1951436c6c9cSStella Laurenzo         mlirLocationPrint(self, printAccum.getCallback(),
1952436c6c9cSStella Laurenzo                           printAccum.getUserData());
1953436c6c9cSStella Laurenzo         return printAccum.join();
1954436c6c9cSStella Laurenzo       });
1955436c6c9cSStella Laurenzo 
1956436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1957436c6c9cSStella Laurenzo   // Mapping of Module
1958436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
1959f05ff4f7SStella Laurenzo   py::class_<PyModule>(m, "Module", py::module_local())
1960436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
1961436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
1962436c6c9cSStella Laurenzo       .def_static(
1963436c6c9cSStella Laurenzo           "parse",
1964436c6c9cSStella Laurenzo           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
1965436c6c9cSStella Laurenzo             MlirModule module = mlirModuleCreateParse(
1966436c6c9cSStella Laurenzo                 context->get(), toMlirStringRef(moduleAsm));
1967436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
1968436c6c9cSStella Laurenzo             // in C API.
1969436c6c9cSStella Laurenzo             if (mlirModuleIsNull(module)) {
1970436c6c9cSStella Laurenzo               throw SetPyError(
1971436c6c9cSStella Laurenzo                   PyExc_ValueError,
1972436c6c9cSStella Laurenzo                   "Unable to parse module assembly (see diagnostics)");
1973436c6c9cSStella Laurenzo             }
1974436c6c9cSStella Laurenzo             return PyModule::forModule(module).releaseObject();
1975436c6c9cSStella Laurenzo           },
1976436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
1977436c6c9cSStella Laurenzo           kModuleParseDocstring)
1978436c6c9cSStella Laurenzo       .def_static(
1979436c6c9cSStella Laurenzo           "create",
1980436c6c9cSStella Laurenzo           [](DefaultingPyLocation loc) {
1981436c6c9cSStella Laurenzo             MlirModule module = mlirModuleCreateEmpty(loc);
1982436c6c9cSStella Laurenzo             return PyModule::forModule(module).releaseObject();
1983436c6c9cSStella Laurenzo           },
1984436c6c9cSStella Laurenzo           py::arg("loc") = py::none(), "Creates an empty module")
1985436c6c9cSStella Laurenzo       .def_property_readonly(
1986436c6c9cSStella Laurenzo           "context",
1987436c6c9cSStella Laurenzo           [](PyModule &self) { return self.getContext().getObject(); },
1988436c6c9cSStella Laurenzo           "Context that created the Module")
1989436c6c9cSStella Laurenzo       .def_property_readonly(
1990436c6c9cSStella Laurenzo           "operation",
1991436c6c9cSStella Laurenzo           [](PyModule &self) {
1992436c6c9cSStella Laurenzo             return PyOperation::forOperation(self.getContext(),
1993436c6c9cSStella Laurenzo                                              mlirModuleGetOperation(self.get()),
1994436c6c9cSStella Laurenzo                                              self.getRef().releaseObject())
1995436c6c9cSStella Laurenzo                 .releaseObject();
1996436c6c9cSStella Laurenzo           },
1997436c6c9cSStella Laurenzo           "Accesses the module as an operation")
1998436c6c9cSStella Laurenzo       .def_property_readonly(
1999436c6c9cSStella Laurenzo           "body",
2000436c6c9cSStella Laurenzo           [](PyModule &self) {
2001436c6c9cSStella Laurenzo             PyOperationRef module_op = PyOperation::forOperation(
2002436c6c9cSStella Laurenzo                 self.getContext(), mlirModuleGetOperation(self.get()),
2003436c6c9cSStella Laurenzo                 self.getRef().releaseObject());
2004436c6c9cSStella Laurenzo             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2005436c6c9cSStella Laurenzo             return returnBlock;
2006436c6c9cSStella Laurenzo           },
2007436c6c9cSStella Laurenzo           "Return the block for this module")
2008436c6c9cSStella Laurenzo       .def(
2009436c6c9cSStella Laurenzo           "dump",
2010436c6c9cSStella Laurenzo           [](PyModule &self) {
2011436c6c9cSStella Laurenzo             mlirOperationDump(mlirModuleGetOperation(self.get()));
2012436c6c9cSStella Laurenzo           },
2013436c6c9cSStella Laurenzo           kDumpDocstring)
2014436c6c9cSStella Laurenzo       .def(
2015436c6c9cSStella Laurenzo           "__str__",
2016436c6c9cSStella Laurenzo           [](PyModule &self) {
2017436c6c9cSStella Laurenzo             MlirOperation operation = mlirModuleGetOperation(self.get());
2018436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2019436c6c9cSStella Laurenzo             mlirOperationPrint(operation, printAccum.getCallback(),
2020436c6c9cSStella Laurenzo                                printAccum.getUserData());
2021436c6c9cSStella Laurenzo             return printAccum.join();
2022436c6c9cSStella Laurenzo           },
2023436c6c9cSStella Laurenzo           kOperationStrDunderDocstring);
2024436c6c9cSStella Laurenzo 
2025436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2026436c6c9cSStella Laurenzo   // Mapping of Operation.
2027436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2028f05ff4f7SStella Laurenzo   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
20291fb2e842SStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
20301fb2e842SStella Laurenzo                              [](PyOperationBase &self) {
20311fb2e842SStella Laurenzo                                return self.getOperation().getCapsule();
20321fb2e842SStella Laurenzo                              })
2033436c6c9cSStella Laurenzo       .def("__eq__",
2034436c6c9cSStella Laurenzo            [](PyOperationBase &self, PyOperationBase &other) {
2035436c6c9cSStella Laurenzo              return &self.getOperation() == &other.getOperation();
2036436c6c9cSStella Laurenzo            })
2037436c6c9cSStella Laurenzo       .def("__eq__",
2038436c6c9cSStella Laurenzo            [](PyOperationBase &self, py::object other) { return false; })
2039436c6c9cSStella Laurenzo       .def_property_readonly("attributes",
2040436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
2041436c6c9cSStella Laurenzo                                return PyOpAttributeMap(
2042436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
2043436c6c9cSStella Laurenzo                              })
2044436c6c9cSStella Laurenzo       .def_property_readonly("operands",
2045436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
2046436c6c9cSStella Laurenzo                                return PyOpOperandList(
2047436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
2048436c6c9cSStella Laurenzo                              })
2049436c6c9cSStella Laurenzo       .def_property_readonly("regions",
2050436c6c9cSStella Laurenzo                              [](PyOperationBase &self) {
2051436c6c9cSStella Laurenzo                                return PyRegionList(
2052436c6c9cSStella Laurenzo                                    self.getOperation().getRef());
2053436c6c9cSStella Laurenzo                              })
2054436c6c9cSStella Laurenzo       .def_property_readonly(
2055436c6c9cSStella Laurenzo           "results",
2056436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2057436c6c9cSStella Laurenzo             return PyOpResultList(self.getOperation().getRef());
2058436c6c9cSStella Laurenzo           },
2059436c6c9cSStella Laurenzo           "Returns the list of Operation results.")
2060436c6c9cSStella Laurenzo       .def_property_readonly(
2061436c6c9cSStella Laurenzo           "result",
2062436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2063436c6c9cSStella Laurenzo             auto &operation = self.getOperation();
2064436c6c9cSStella Laurenzo             auto numResults = mlirOperationGetNumResults(operation);
2065436c6c9cSStella Laurenzo             if (numResults != 1) {
2066436c6c9cSStella Laurenzo               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2067436c6c9cSStella Laurenzo               throw SetPyError(
2068436c6c9cSStella Laurenzo                   PyExc_ValueError,
2069436c6c9cSStella Laurenzo                   Twine("Cannot call .result on operation ") +
2070436c6c9cSStella Laurenzo                       StringRef(name.data, name.length) + " which has " +
2071436c6c9cSStella Laurenzo                       Twine(numResults) +
2072436c6c9cSStella Laurenzo                       " results (it is only valid for operations with a "
2073436c6c9cSStella Laurenzo                       "single result)");
2074436c6c9cSStella Laurenzo             }
2075436c6c9cSStella Laurenzo             return PyOpResult(operation.getRef(),
2076436c6c9cSStella Laurenzo                               mlirOperationGetResult(operation, 0));
2077436c6c9cSStella Laurenzo           },
2078436c6c9cSStella Laurenzo           "Shortcut to get an op result if it has only one (throws an error "
2079436c6c9cSStella Laurenzo           "otherwise).")
2080436c6c9cSStella Laurenzo       .def("__iter__",
2081436c6c9cSStella Laurenzo            [](PyOperationBase &self) {
2082436c6c9cSStella Laurenzo              return PyRegionIterator(self.getOperation().getRef());
2083436c6c9cSStella Laurenzo            })
2084436c6c9cSStella Laurenzo       .def(
2085436c6c9cSStella Laurenzo           "__str__",
2086436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2087436c6c9cSStella Laurenzo             return self.getAsm(/*binary=*/false,
2088436c6c9cSStella Laurenzo                                /*largeElementsLimit=*/llvm::None,
2089436c6c9cSStella Laurenzo                                /*enableDebugInfo=*/false,
2090436c6c9cSStella Laurenzo                                /*prettyDebugInfo=*/false,
2091436c6c9cSStella Laurenzo                                /*printGenericOpForm=*/false,
2092436c6c9cSStella Laurenzo                                /*useLocalScope=*/false);
2093436c6c9cSStella Laurenzo           },
2094436c6c9cSStella Laurenzo           "Returns the assembly form of the operation.")
2095436c6c9cSStella Laurenzo       .def("print", &PyOperationBase::print,
2096436c6c9cSStella Laurenzo            // Careful: Lots of arguments must match up with print method.
2097436c6c9cSStella Laurenzo            py::arg("file") = py::none(), py::arg("binary") = false,
2098436c6c9cSStella Laurenzo            py::arg("large_elements_limit") = py::none(),
2099436c6c9cSStella Laurenzo            py::arg("enable_debug_info") = false,
2100436c6c9cSStella Laurenzo            py::arg("pretty_debug_info") = false,
2101436c6c9cSStella Laurenzo            py::arg("print_generic_op_form") = false,
2102436c6c9cSStella Laurenzo            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2103436c6c9cSStella Laurenzo       .def("get_asm", &PyOperationBase::getAsm,
2104436c6c9cSStella Laurenzo            // Careful: Lots of arguments must match up with get_asm method.
2105436c6c9cSStella Laurenzo            py::arg("binary") = false,
2106436c6c9cSStella Laurenzo            py::arg("large_elements_limit") = py::none(),
2107436c6c9cSStella Laurenzo            py::arg("enable_debug_info") = false,
2108436c6c9cSStella Laurenzo            py::arg("pretty_debug_info") = false,
2109436c6c9cSStella Laurenzo            py::arg("print_generic_op_form") = false,
2110436c6c9cSStella Laurenzo            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2111436c6c9cSStella Laurenzo       .def(
2112436c6c9cSStella Laurenzo           "verify",
2113436c6c9cSStella Laurenzo           [](PyOperationBase &self) {
2114436c6c9cSStella Laurenzo             return mlirOperationVerify(self.getOperation());
2115436c6c9cSStella Laurenzo           },
2116436c6c9cSStella Laurenzo           "Verify the operation and return true if it passes, false if it "
2117436c6c9cSStella Laurenzo           "fails.");
2118436c6c9cSStella Laurenzo 
2119f05ff4f7SStella Laurenzo   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2120436c6c9cSStella Laurenzo       .def_static("create", &PyOperation::create, py::arg("name"),
2121436c6c9cSStella Laurenzo                   py::arg("results") = py::none(),
2122436c6c9cSStella Laurenzo                   py::arg("operands") = py::none(),
2123436c6c9cSStella Laurenzo                   py::arg("attributes") = py::none(),
2124436c6c9cSStella Laurenzo                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2125436c6c9cSStella Laurenzo                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2126436c6c9cSStella Laurenzo                   kOperationCreateDocstring)
2127c65bb760SJohn Demme       .def_property_readonly("parent",
21281689dadeSJohn Demme                              [](PyOperation &self) -> py::object {
21291689dadeSJohn Demme                                auto parent = self.getParentOperation();
21301689dadeSJohn Demme                                if (parent)
21311689dadeSJohn Demme                                  return parent->getObject();
21321689dadeSJohn Demme                                return py::none();
2133c65bb760SJohn Demme                              })
213449745f87SMike Urbach       .def("erase", &PyOperation::erase)
21350126e906SJohn Demme       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
21360126e906SJohn Demme                              &PyOperation::getCapsule)
21370126e906SJohn Demme       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2138436c6c9cSStella Laurenzo       .def_property_readonly("name",
2139436c6c9cSStella Laurenzo                              [](PyOperation &self) {
214049745f87SMike Urbach                                self.checkValid();
2141436c6c9cSStella Laurenzo                                MlirOperation operation = self.get();
2142436c6c9cSStella Laurenzo                                MlirStringRef name = mlirIdentifierStr(
2143436c6c9cSStella Laurenzo                                    mlirOperationGetName(operation));
2144436c6c9cSStella Laurenzo                                return py::str(name.data, name.length);
2145436c6c9cSStella Laurenzo                              })
2146436c6c9cSStella Laurenzo       .def_property_readonly(
2147436c6c9cSStella Laurenzo           "context",
214849745f87SMike Urbach           [](PyOperation &self) {
214949745f87SMike Urbach             self.checkValid();
215049745f87SMike Urbach             return self.getContext().getObject();
215149745f87SMike Urbach           },
2152436c6c9cSStella Laurenzo           "Context that owns the Operation")
2153436c6c9cSStella Laurenzo       .def_property_readonly("opview", &PyOperation::createOpView);
2154436c6c9cSStella Laurenzo 
2155436c6c9cSStella Laurenzo   auto opViewClass =
2156f05ff4f7SStella Laurenzo       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2157436c6c9cSStella Laurenzo           .def(py::init<py::object>())
2158436c6c9cSStella Laurenzo           .def_property_readonly("operation", &PyOpView::getOperationObject)
2159436c6c9cSStella Laurenzo           .def_property_readonly(
2160436c6c9cSStella Laurenzo               "context",
2161436c6c9cSStella Laurenzo               [](PyOpView &self) {
2162436c6c9cSStella Laurenzo                 return self.getOperation().getContext().getObject();
2163436c6c9cSStella Laurenzo               },
2164436c6c9cSStella Laurenzo               "Context that owns the Operation")
2165436c6c9cSStella Laurenzo           .def("__str__", [](PyOpView &self) {
2166436c6c9cSStella Laurenzo             return py::str(self.getOperationObject());
2167436c6c9cSStella Laurenzo           });
2168436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2169436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2170436c6c9cSStella Laurenzo   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2171436c6c9cSStella Laurenzo   opViewClass.attr("build_generic") = classmethod(
2172436c6c9cSStella Laurenzo       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2173436c6c9cSStella Laurenzo       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2174436c6c9cSStella Laurenzo       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2175436c6c9cSStella Laurenzo       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2176436c6c9cSStella Laurenzo       "Builds a specific, generated OpView based on class level attributes.");
2177436c6c9cSStella Laurenzo 
2178436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2179436c6c9cSStella Laurenzo   // Mapping of PyRegion.
2180436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2181f05ff4f7SStella Laurenzo   py::class_<PyRegion>(m, "Region", py::module_local())
2182436c6c9cSStella Laurenzo       .def_property_readonly(
2183436c6c9cSStella Laurenzo           "blocks",
2184436c6c9cSStella Laurenzo           [](PyRegion &self) {
2185436c6c9cSStella Laurenzo             return PyBlockList(self.getParentOperation(), self.get());
2186436c6c9cSStella Laurenzo           },
2187436c6c9cSStella Laurenzo           "Returns a forward-optimized sequence of blocks.")
2188436c6c9cSStella Laurenzo       .def(
2189436c6c9cSStella Laurenzo           "__iter__",
2190436c6c9cSStella Laurenzo           [](PyRegion &self) {
2191436c6c9cSStella Laurenzo             self.checkValid();
2192436c6c9cSStella Laurenzo             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2193436c6c9cSStella Laurenzo             return PyBlockIterator(self.getParentOperation(), firstBlock);
2194436c6c9cSStella Laurenzo           },
2195436c6c9cSStella Laurenzo           "Iterates over blocks in the region.")
2196436c6c9cSStella Laurenzo       .def("__eq__",
2197436c6c9cSStella Laurenzo            [](PyRegion &self, PyRegion &other) {
2198436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2199436c6c9cSStella Laurenzo            })
2200436c6c9cSStella Laurenzo       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2201436c6c9cSStella Laurenzo 
2202436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2203436c6c9cSStella Laurenzo   // Mapping of PyBlock.
2204436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2205f05ff4f7SStella Laurenzo   py::class_<PyBlock>(m, "Block", py::module_local())
2206436c6c9cSStella Laurenzo       .def_property_readonly(
220796fbd5cdSJohn Demme           "owner",
220896fbd5cdSJohn Demme           [](PyBlock &self) {
220996fbd5cdSJohn Demme             return self.getParentOperation()->createOpView();
221096fbd5cdSJohn Demme           },
221196fbd5cdSJohn Demme           "Returns the owning operation of this block.")
221296fbd5cdSJohn Demme       .def_property_readonly(
22138e6c55c9SStella Laurenzo           "region",
22148e6c55c9SStella Laurenzo           [](PyBlock &self) {
22158e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
22168e6c55c9SStella Laurenzo             return PyRegion(self.getParentOperation(), region);
22178e6c55c9SStella Laurenzo           },
22188e6c55c9SStella Laurenzo           "Returns the owning region of this block.")
22198e6c55c9SStella Laurenzo       .def_property_readonly(
2220436c6c9cSStella Laurenzo           "arguments",
2221436c6c9cSStella Laurenzo           [](PyBlock &self) {
2222436c6c9cSStella Laurenzo             return PyBlockArgumentList(self.getParentOperation(), self.get());
2223436c6c9cSStella Laurenzo           },
2224436c6c9cSStella Laurenzo           "Returns a list of block arguments.")
2225436c6c9cSStella Laurenzo       .def_property_readonly(
2226436c6c9cSStella Laurenzo           "operations",
2227436c6c9cSStella Laurenzo           [](PyBlock &self) {
2228436c6c9cSStella Laurenzo             return PyOperationList(self.getParentOperation(), self.get());
2229436c6c9cSStella Laurenzo           },
2230436c6c9cSStella Laurenzo           "Returns a forward-optimized sequence of operations.")
2231436c6c9cSStella Laurenzo       .def(
22328e6c55c9SStella Laurenzo           "create_before",
22338e6c55c9SStella Laurenzo           [](PyBlock &self, py::args pyArgTypes) {
22348e6c55c9SStella Laurenzo             self.checkValid();
22358e6c55c9SStella Laurenzo             llvm::SmallVector<MlirType, 4> argTypes;
22368e6c55c9SStella Laurenzo             argTypes.reserve(pyArgTypes.size());
22378e6c55c9SStella Laurenzo             for (auto &pyArg : pyArgTypes) {
22388e6c55c9SStella Laurenzo               argTypes.push_back(pyArg.cast<PyType &>());
22398e6c55c9SStella Laurenzo             }
22408e6c55c9SStella Laurenzo 
22418e6c55c9SStella Laurenzo             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
22428e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
22438e6c55c9SStella Laurenzo             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
22448e6c55c9SStella Laurenzo             return PyBlock(self.getParentOperation(), block);
22458e6c55c9SStella Laurenzo           },
22468e6c55c9SStella Laurenzo           "Creates and returns a new Block before this block "
22478e6c55c9SStella Laurenzo           "(with given argument types).")
22488e6c55c9SStella Laurenzo       .def(
22498e6c55c9SStella Laurenzo           "create_after",
22508e6c55c9SStella Laurenzo           [](PyBlock &self, py::args pyArgTypes) {
22518e6c55c9SStella Laurenzo             self.checkValid();
22528e6c55c9SStella Laurenzo             llvm::SmallVector<MlirType, 4> argTypes;
22538e6c55c9SStella Laurenzo             argTypes.reserve(pyArgTypes.size());
22548e6c55c9SStella Laurenzo             for (auto &pyArg : pyArgTypes) {
22558e6c55c9SStella Laurenzo               argTypes.push_back(pyArg.cast<PyType &>());
22568e6c55c9SStella Laurenzo             }
22578e6c55c9SStella Laurenzo 
22588e6c55c9SStella Laurenzo             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
22598e6c55c9SStella Laurenzo             MlirRegion region = mlirBlockGetParentRegion(self.get());
22608e6c55c9SStella Laurenzo             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
22618e6c55c9SStella Laurenzo             return PyBlock(self.getParentOperation(), block);
22628e6c55c9SStella Laurenzo           },
22638e6c55c9SStella Laurenzo           "Creates and returns a new Block after this block "
22648e6c55c9SStella Laurenzo           "(with given argument types).")
22658e6c55c9SStella Laurenzo       .def(
2266436c6c9cSStella Laurenzo           "__iter__",
2267436c6c9cSStella Laurenzo           [](PyBlock &self) {
2268436c6c9cSStella Laurenzo             self.checkValid();
2269436c6c9cSStella Laurenzo             MlirOperation firstOperation =
2270436c6c9cSStella Laurenzo                 mlirBlockGetFirstOperation(self.get());
2271436c6c9cSStella Laurenzo             return PyOperationIterator(self.getParentOperation(),
2272436c6c9cSStella Laurenzo                                        firstOperation);
2273436c6c9cSStella Laurenzo           },
2274436c6c9cSStella Laurenzo           "Iterates over operations in the block.")
2275436c6c9cSStella Laurenzo       .def("__eq__",
2276436c6c9cSStella Laurenzo            [](PyBlock &self, PyBlock &other) {
2277436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2278436c6c9cSStella Laurenzo            })
2279436c6c9cSStella Laurenzo       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2280436c6c9cSStella Laurenzo       .def(
2281436c6c9cSStella Laurenzo           "__str__",
2282436c6c9cSStella Laurenzo           [](PyBlock &self) {
2283436c6c9cSStella Laurenzo             self.checkValid();
2284436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2285436c6c9cSStella Laurenzo             mlirBlockPrint(self.get(), printAccum.getCallback(),
2286436c6c9cSStella Laurenzo                            printAccum.getUserData());
2287436c6c9cSStella Laurenzo             return printAccum.join();
2288436c6c9cSStella Laurenzo           },
2289436c6c9cSStella Laurenzo           "Returns the assembly form of the block.");
2290436c6c9cSStella Laurenzo 
2291436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2292436c6c9cSStella Laurenzo   // Mapping of PyInsertionPoint.
2293436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2294436c6c9cSStella Laurenzo 
2295f05ff4f7SStella Laurenzo   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2296436c6c9cSStella Laurenzo       .def(py::init<PyBlock &>(), py::arg("block"),
2297436c6c9cSStella Laurenzo            "Inserts after the last operation but still inside the block.")
2298436c6c9cSStella Laurenzo       .def("__enter__", &PyInsertionPoint::contextEnter)
2299436c6c9cSStella Laurenzo       .def("__exit__", &PyInsertionPoint::contextExit)
2300436c6c9cSStella Laurenzo       .def_property_readonly_static(
2301436c6c9cSStella Laurenzo           "current",
2302436c6c9cSStella Laurenzo           [](py::object & /*class*/) {
2303436c6c9cSStella Laurenzo             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2304436c6c9cSStella Laurenzo             if (!ip)
2305436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2306436c6c9cSStella Laurenzo             return ip;
2307436c6c9cSStella Laurenzo           },
2308436c6c9cSStella Laurenzo           "Gets the InsertionPoint bound to the current thread or raises "
2309436c6c9cSStella Laurenzo           "ValueError if none has been set")
2310436c6c9cSStella Laurenzo       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2311436c6c9cSStella Laurenzo            "Inserts before a referenced operation.")
2312436c6c9cSStella Laurenzo       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2313436c6c9cSStella Laurenzo                   py::arg("block"), "Inserts at the beginning of the block.")
2314436c6c9cSStella Laurenzo       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2315436c6c9cSStella Laurenzo                   py::arg("block"), "Inserts before the block terminator.")
2316436c6c9cSStella Laurenzo       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
23178e6c55c9SStella Laurenzo            "Inserts an operation.")
23188e6c55c9SStella Laurenzo       .def_property_readonly(
23198e6c55c9SStella Laurenzo           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
23208e6c55c9SStella Laurenzo           "Returns the block that this InsertionPoint points to.");
2321436c6c9cSStella Laurenzo 
2322436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2323436c6c9cSStella Laurenzo   // Mapping of PyAttribute.
2324436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2325f05ff4f7SStella Laurenzo   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2326b57d6fe4SStella Laurenzo       // Delegate to the PyAttribute copy constructor, which will also lifetime
2327b57d6fe4SStella Laurenzo       // extend the backing context which owns the MlirAttribute.
2328b57d6fe4SStella Laurenzo       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2329b57d6fe4SStella Laurenzo            "Casts the passed attribute to the generic Attribute")
2330436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2331436c6c9cSStella Laurenzo                              &PyAttribute::getCapsule)
2332436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2333436c6c9cSStella Laurenzo       .def_static(
2334436c6c9cSStella Laurenzo           "parse",
2335436c6c9cSStella Laurenzo           [](std::string attrSpec, DefaultingPyMlirContext context) {
2336436c6c9cSStella Laurenzo             MlirAttribute type = mlirAttributeParseGet(
2337436c6c9cSStella Laurenzo                 context->get(), toMlirStringRef(attrSpec));
2338436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
2339436c6c9cSStella Laurenzo             // in C API.
2340436c6c9cSStella Laurenzo             if (mlirAttributeIsNull(type)) {
2341436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
2342436c6c9cSStella Laurenzo                                Twine("Unable to parse attribute: '") +
2343436c6c9cSStella Laurenzo                                    attrSpec + "'");
2344436c6c9cSStella Laurenzo             }
2345436c6c9cSStella Laurenzo             return PyAttribute(context->getRef(), type);
2346436c6c9cSStella Laurenzo           },
2347436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
2348436c6c9cSStella Laurenzo           "Parses an attribute from an assembly form")
2349436c6c9cSStella Laurenzo       .def_property_readonly(
2350436c6c9cSStella Laurenzo           "context",
2351436c6c9cSStella Laurenzo           [](PyAttribute &self) { return self.getContext().getObject(); },
2352436c6c9cSStella Laurenzo           "Context that owns the Attribute")
2353436c6c9cSStella Laurenzo       .def_property_readonly("type",
2354436c6c9cSStella Laurenzo                              [](PyAttribute &self) {
2355436c6c9cSStella Laurenzo                                return PyType(self.getContext()->getRef(),
2356436c6c9cSStella Laurenzo                                              mlirAttributeGetType(self));
2357436c6c9cSStella Laurenzo                              })
2358436c6c9cSStella Laurenzo       .def(
2359436c6c9cSStella Laurenzo           "get_named",
2360436c6c9cSStella Laurenzo           [](PyAttribute &self, std::string name) {
2361436c6c9cSStella Laurenzo             return PyNamedAttribute(self, std::move(name));
2362436c6c9cSStella Laurenzo           },
2363436c6c9cSStella Laurenzo           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2364436c6c9cSStella Laurenzo       .def("__eq__",
2365436c6c9cSStella Laurenzo            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2366436c6c9cSStella Laurenzo       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2367*47cc166bSJohn Demme       .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; })
2368436c6c9cSStella Laurenzo       .def(
2369436c6c9cSStella Laurenzo           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2370436c6c9cSStella Laurenzo           kDumpDocstring)
2371436c6c9cSStella Laurenzo       .def(
2372436c6c9cSStella Laurenzo           "__str__",
2373436c6c9cSStella Laurenzo           [](PyAttribute &self) {
2374436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2375436c6c9cSStella Laurenzo             mlirAttributePrint(self, printAccum.getCallback(),
2376436c6c9cSStella Laurenzo                                printAccum.getUserData());
2377436c6c9cSStella Laurenzo             return printAccum.join();
2378436c6c9cSStella Laurenzo           },
2379436c6c9cSStella Laurenzo           "Returns the assembly form of the Attribute.")
2380436c6c9cSStella Laurenzo       .def("__repr__", [](PyAttribute &self) {
2381436c6c9cSStella Laurenzo         // Generally, assembly formats are not printed for __repr__ because
2382436c6c9cSStella Laurenzo         // this can cause exceptionally long debug output and exceptions.
2383436c6c9cSStella Laurenzo         // However, attribute values are generally considered useful and are
2384436c6c9cSStella Laurenzo         // printed. This may need to be re-evaluated if debug dumps end up
2385436c6c9cSStella Laurenzo         // being excessive.
2386436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
2387436c6c9cSStella Laurenzo         printAccum.parts.append("Attribute(");
2388436c6c9cSStella Laurenzo         mlirAttributePrint(self, printAccum.getCallback(),
2389436c6c9cSStella Laurenzo                            printAccum.getUserData());
2390436c6c9cSStella Laurenzo         printAccum.parts.append(")");
2391436c6c9cSStella Laurenzo         return printAccum.join();
2392436c6c9cSStella Laurenzo       });
2393436c6c9cSStella Laurenzo 
2394436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2395436c6c9cSStella Laurenzo   // Mapping of PyNamedAttribute
2396436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2397f05ff4f7SStella Laurenzo   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2398436c6c9cSStella Laurenzo       .def("__repr__",
2399436c6c9cSStella Laurenzo            [](PyNamedAttribute &self) {
2400436c6c9cSStella Laurenzo              PyPrintAccumulator printAccum;
2401436c6c9cSStella Laurenzo              printAccum.parts.append("NamedAttribute(");
2402436c6c9cSStella Laurenzo              printAccum.parts.append(
2403436c6c9cSStella Laurenzo                  mlirIdentifierStr(self.namedAttr.name).data);
2404436c6c9cSStella Laurenzo              printAccum.parts.append("=");
2405436c6c9cSStella Laurenzo              mlirAttributePrint(self.namedAttr.attribute,
2406436c6c9cSStella Laurenzo                                 printAccum.getCallback(),
2407436c6c9cSStella Laurenzo                                 printAccum.getUserData());
2408436c6c9cSStella Laurenzo              printAccum.parts.append(")");
2409436c6c9cSStella Laurenzo              return printAccum.join();
2410436c6c9cSStella Laurenzo            })
2411436c6c9cSStella Laurenzo       .def_property_readonly(
2412436c6c9cSStella Laurenzo           "name",
2413436c6c9cSStella Laurenzo           [](PyNamedAttribute &self) {
2414436c6c9cSStella Laurenzo             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2415436c6c9cSStella Laurenzo                            mlirIdentifierStr(self.namedAttr.name).length);
2416436c6c9cSStella Laurenzo           },
2417436c6c9cSStella Laurenzo           "The name of the NamedAttribute binding")
2418436c6c9cSStella Laurenzo       .def_property_readonly(
2419436c6c9cSStella Laurenzo           "attr",
2420436c6c9cSStella Laurenzo           [](PyNamedAttribute &self) {
2421436c6c9cSStella Laurenzo             // TODO: When named attribute is removed/refactored, also remove
2422436c6c9cSStella Laurenzo             // this constructor (it does an inefficient table lookup).
2423436c6c9cSStella Laurenzo             auto contextRef = PyMlirContext::forContext(
2424436c6c9cSStella Laurenzo                 mlirAttributeGetContext(self.namedAttr.attribute));
2425436c6c9cSStella Laurenzo             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2426436c6c9cSStella Laurenzo           },
2427436c6c9cSStella Laurenzo           py::keep_alive<0, 1>(),
2428436c6c9cSStella Laurenzo           "The underlying generic attribute of the NamedAttribute binding");
2429436c6c9cSStella Laurenzo 
2430436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2431436c6c9cSStella Laurenzo   // Mapping of PyType.
2432436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2433f05ff4f7SStella Laurenzo   py::class_<PyType>(m, "Type", py::module_local())
2434b57d6fe4SStella Laurenzo       // Delegate to the PyType copy constructor, which will also lifetime
2435b57d6fe4SStella Laurenzo       // extend the backing context which owns the MlirType.
2436b57d6fe4SStella Laurenzo       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2437b57d6fe4SStella Laurenzo            "Casts the passed type to the generic Type")
2438436c6c9cSStella Laurenzo       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2439436c6c9cSStella Laurenzo       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2440436c6c9cSStella Laurenzo       .def_static(
2441436c6c9cSStella Laurenzo           "parse",
2442436c6c9cSStella Laurenzo           [](std::string typeSpec, DefaultingPyMlirContext context) {
2443436c6c9cSStella Laurenzo             MlirType type =
2444436c6c9cSStella Laurenzo                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2445436c6c9cSStella Laurenzo             // TODO: Rework error reporting once diagnostic engine is exposed
2446436c6c9cSStella Laurenzo             // in C API.
2447436c6c9cSStella Laurenzo             if (mlirTypeIsNull(type)) {
2448436c6c9cSStella Laurenzo               throw SetPyError(PyExc_ValueError,
2449436c6c9cSStella Laurenzo                                Twine("Unable to parse type: '") + typeSpec +
2450436c6c9cSStella Laurenzo                                    "'");
2451436c6c9cSStella Laurenzo             }
2452436c6c9cSStella Laurenzo             return PyType(context->getRef(), type);
2453436c6c9cSStella Laurenzo           },
2454436c6c9cSStella Laurenzo           py::arg("asm"), py::arg("context") = py::none(),
2455436c6c9cSStella Laurenzo           kContextParseTypeDocstring)
2456436c6c9cSStella Laurenzo       .def_property_readonly(
2457436c6c9cSStella Laurenzo           "context", [](PyType &self) { return self.getContext().getObject(); },
2458436c6c9cSStella Laurenzo           "Context that owns the Type")
2459436c6c9cSStella Laurenzo       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2460436c6c9cSStella Laurenzo       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2461*47cc166bSJohn Demme       .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; })
2462436c6c9cSStella Laurenzo       .def(
2463436c6c9cSStella Laurenzo           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2464436c6c9cSStella Laurenzo       .def(
2465436c6c9cSStella Laurenzo           "__str__",
2466436c6c9cSStella Laurenzo           [](PyType &self) {
2467436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2468436c6c9cSStella Laurenzo             mlirTypePrint(self, printAccum.getCallback(),
2469436c6c9cSStella Laurenzo                           printAccum.getUserData());
2470436c6c9cSStella Laurenzo             return printAccum.join();
2471436c6c9cSStella Laurenzo           },
2472436c6c9cSStella Laurenzo           "Returns the assembly form of the type.")
2473436c6c9cSStella Laurenzo       .def("__repr__", [](PyType &self) {
2474436c6c9cSStella Laurenzo         // Generally, assembly formats are not printed for __repr__ because
2475436c6c9cSStella Laurenzo         // this can cause exceptionally long debug output and exceptions.
2476436c6c9cSStella Laurenzo         // However, types are an exception as they typically have compact
2477436c6c9cSStella Laurenzo         // assembly forms and printing them is useful.
2478436c6c9cSStella Laurenzo         PyPrintAccumulator printAccum;
2479436c6c9cSStella Laurenzo         printAccum.parts.append("Type(");
2480436c6c9cSStella Laurenzo         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2481436c6c9cSStella Laurenzo         printAccum.parts.append(")");
2482436c6c9cSStella Laurenzo         return printAccum.join();
2483436c6c9cSStella Laurenzo       });
2484436c6c9cSStella Laurenzo 
2485436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2486436c6c9cSStella Laurenzo   // Mapping of Value.
2487436c6c9cSStella Laurenzo   //----------------------------------------------------------------------------
2488f05ff4f7SStella Laurenzo   py::class_<PyValue>(m, "Value", py::module_local())
24893f3d1c90SMike Urbach       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
24903f3d1c90SMike Urbach       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2491436c6c9cSStella Laurenzo       .def_property_readonly(
2492436c6c9cSStella Laurenzo           "context",
2493436c6c9cSStella Laurenzo           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2494436c6c9cSStella Laurenzo           "Context in which the value lives.")
2495436c6c9cSStella Laurenzo       .def(
2496436c6c9cSStella Laurenzo           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2497436c6c9cSStella Laurenzo           kDumpDocstring)
24985664c5e2SJohn Demme       .def_property_readonly(
24995664c5e2SJohn Demme           "owner",
25005664c5e2SJohn Demme           [](PyValue &self) {
25015664c5e2SJohn Demme             assert(mlirOperationEqual(self.getParentOperation()->get(),
25025664c5e2SJohn Demme                                       mlirOpResultGetOwner(self.get())) &&
25035664c5e2SJohn Demme                    "expected the owner of the value in Python to match that in "
25045664c5e2SJohn Demme                    "the IR");
25055664c5e2SJohn Demme             return self.getParentOperation().getObject();
25065664c5e2SJohn Demme           })
2507436c6c9cSStella Laurenzo       .def("__eq__",
2508436c6c9cSStella Laurenzo            [](PyValue &self, PyValue &other) {
2509436c6c9cSStella Laurenzo              return self.get().ptr == other.get().ptr;
2510436c6c9cSStella Laurenzo            })
2511436c6c9cSStella Laurenzo       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2512436c6c9cSStella Laurenzo       .def(
2513436c6c9cSStella Laurenzo           "__str__",
2514436c6c9cSStella Laurenzo           [](PyValue &self) {
2515436c6c9cSStella Laurenzo             PyPrintAccumulator printAccum;
2516436c6c9cSStella Laurenzo             printAccum.parts.append("Value(");
2517436c6c9cSStella Laurenzo             mlirValuePrint(self.get(), printAccum.getCallback(),
2518436c6c9cSStella Laurenzo                            printAccum.getUserData());
2519436c6c9cSStella Laurenzo             printAccum.parts.append(")");
2520436c6c9cSStella Laurenzo             return printAccum.join();
2521436c6c9cSStella Laurenzo           },
2522436c6c9cSStella Laurenzo           kValueDunderStrDocstring)
2523436c6c9cSStella Laurenzo       .def_property_readonly("type", [](PyValue &self) {
2524436c6c9cSStella Laurenzo         return PyType(self.getParentOperation()->getContext(),
2525436c6c9cSStella Laurenzo                       mlirValueGetType(self.get()));
2526436c6c9cSStella Laurenzo       });
2527436c6c9cSStella Laurenzo   PyBlockArgument::bind(m);
2528436c6c9cSStella Laurenzo   PyOpResult::bind(m);
2529436c6c9cSStella Laurenzo 
2530436c6c9cSStella Laurenzo   // Container bindings.
2531436c6c9cSStella Laurenzo   PyBlockArgumentList::bind(m);
2532436c6c9cSStella Laurenzo   PyBlockIterator::bind(m);
2533436c6c9cSStella Laurenzo   PyBlockList::bind(m);
2534436c6c9cSStella Laurenzo   PyOperationIterator::bind(m);
2535436c6c9cSStella Laurenzo   PyOperationList::bind(m);
2536436c6c9cSStella Laurenzo   PyOpAttributeMap::bind(m);
2537436c6c9cSStella Laurenzo   PyOpOperandList::bind(m);
2538436c6c9cSStella Laurenzo   PyOpResultList::bind(m);
2539436c6c9cSStella Laurenzo   PyRegionIterator::bind(m);
2540436c6c9cSStella Laurenzo   PyRegionList::bind(m);
25414acd8457SAlex Zinenko 
25424acd8457SAlex Zinenko   // Debug bindings.
25434acd8457SAlex Zinenko   PyGlobalDebugFlag::bind(m);
2544436c6c9cSStella Laurenzo }
2545