1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include <pybind11/stl.h>
22 
23 namespace py = pybind11;
24 using namespace mlir;
25 using namespace mlir::python;
26 
27 using llvm::SmallVector;
28 using llvm::StringRef;
29 using llvm::Twine;
30 
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
34 
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

// Note: argument names listed here are the Python keyword names (snake_case).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
132 
133 //------------------------------------------------------------------------------
134 // Utilities.
135 //------------------------------------------------------------------------------
136 
137 /// Helper for creating an @classmethod.
138 template <class Func, typename... Args>
139 py::object classmethod(Func f, Args... args) {
140   py::object cf = py::cpp_function(f, args...);
141   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
142 }
143 
144 static py::object
145 createCustomDialectWrapper(const std::string &dialectNamespace,
146                            py::object dialectDescriptor) {
147   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
148   if (!dialectClass) {
149     // Use the base class.
150     return py::cast(PyDialect(std::move(dialectDescriptor)));
151   }
152 
153   // Create the custom implementation.
154   return (*dialectClass)(std::move(dialectDescriptor));
155 }
156 
157 static MlirStringRef toMlirStringRef(const std::string &s) {
158   return mlirStringRefCreate(s.data(), s.size());
159 }
160 
161 /// Wrapper for the global LLVM debugging flag.
162 struct PyGlobalDebugFlag {
163   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
164 
165   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
166 
167   static void bind(py::module &m) {
168     // Debug flags.
169     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
170         .def_property_static("flag", &PyGlobalDebugFlag::get,
171                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
172   }
173 };
174 
175 //------------------------------------------------------------------------------
176 // Collections.
177 //------------------------------------------------------------------------------
178 
179 namespace {
180 
181 class PyRegionIterator {
182 public:
183   PyRegionIterator(PyOperationRef operation)
184       : operation(std::move(operation)) {}
185 
186   PyRegionIterator &dunderIter() { return *this; }
187 
188   PyRegion dunderNext() {
189     operation->checkValid();
190     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
191       throw py::stop_iteration();
192     }
193     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
194     return PyRegion(operation, region);
195   }
196 
197   static void bind(py::module &m) {
198     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
199         .def("__iter__", &PyRegionIterator::dunderIter)
200         .def("__next__", &PyRegionIterator::dunderNext);
201   }
202 
203 private:
204   PyOperationRef operation;
205   int nextIndex = 0;
206 };
207 
208 /// Regions of an op are fixed length and indexed numerically so are represented
209 /// with a sequence-like container.
210 class PyRegionList {
211 public:
212   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
213 
214   intptr_t dunderLen() {
215     operation->checkValid();
216     return mlirOperationGetNumRegions(operation->get());
217   }
218 
219   PyRegion dunderGetItem(intptr_t index) {
220     // dunderLen checks validity.
221     if (index < 0 || index >= dunderLen()) {
222       throw SetPyError(PyExc_IndexError,
223                        "attempt to access out of bounds region");
224     }
225     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
226     return PyRegion(operation, region);
227   }
228 
229   static void bind(py::module &m) {
230     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
231         .def("__len__", &PyRegionList::dunderLen)
232         .def("__getitem__", &PyRegionList::dunderGetItem);
233   }
234 
235 private:
236   PyOperationRef operation;
237 };
238 
239 class PyBlockIterator {
240 public:
241   PyBlockIterator(PyOperationRef operation, MlirBlock next)
242       : operation(std::move(operation)), next(next) {}
243 
244   PyBlockIterator &dunderIter() { return *this; }
245 
246   PyBlock dunderNext() {
247     operation->checkValid();
248     if (mlirBlockIsNull(next)) {
249       throw py::stop_iteration();
250     }
251 
252     PyBlock returnBlock(operation, next);
253     next = mlirBlockGetNextInRegion(next);
254     return returnBlock;
255   }
256 
257   static void bind(py::module &m) {
258     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
259         .def("__iter__", &PyBlockIterator::dunderIter)
260         .def("__next__", &PyBlockIterator::dunderNext);
261   }
262 
263 private:
264   PyOperationRef operation;
265   MlirBlock next;
266 };
267 
268 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
269 /// we present them as a more full-featured list-like container but optimize
270 /// it for forward iteration. Blocks are always owned by a region.
271 class PyBlockList {
272 public:
273   PyBlockList(PyOperationRef operation, MlirRegion region)
274       : operation(std::move(operation)), region(region) {}
275 
276   PyBlockIterator dunderIter() {
277     operation->checkValid();
278     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
279   }
280 
281   intptr_t dunderLen() {
282     operation->checkValid();
283     intptr_t count = 0;
284     MlirBlock block = mlirRegionGetFirstBlock(region);
285     while (!mlirBlockIsNull(block)) {
286       count += 1;
287       block = mlirBlockGetNextInRegion(block);
288     }
289     return count;
290   }
291 
292   PyBlock dunderGetItem(intptr_t index) {
293     operation->checkValid();
294     if (index < 0) {
295       throw SetPyError(PyExc_IndexError,
296                        "attempt to access out of bounds block");
297     }
298     MlirBlock block = mlirRegionGetFirstBlock(region);
299     while (!mlirBlockIsNull(block)) {
300       if (index == 0) {
301         return PyBlock(operation, block);
302       }
303       block = mlirBlockGetNextInRegion(block);
304       index -= 1;
305     }
306     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
307   }
308 
309   PyBlock appendBlock(py::args pyArgTypes) {
310     operation->checkValid();
311     llvm::SmallVector<MlirType, 4> argTypes;
312     argTypes.reserve(pyArgTypes.size());
313     for (auto &pyArg : pyArgTypes) {
314       argTypes.push_back(pyArg.cast<PyType &>());
315     }
316 
317     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
318     mlirRegionAppendOwnedBlock(region, block);
319     return PyBlock(operation, block);
320   }
321 
322   static void bind(py::module &m) {
323     py::class_<PyBlockList>(m, "BlockList", py::module_local())
324         .def("__getitem__", &PyBlockList::dunderGetItem)
325         .def("__iter__", &PyBlockList::dunderIter)
326         .def("__len__", &PyBlockList::dunderLen)
327         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
328   }
329 
330 private:
331   PyOperationRef operation;
332   MlirRegion region;
333 };
334 
335 class PyOperationIterator {
336 public:
337   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
338       : parentOperation(std::move(parentOperation)), next(next) {}
339 
340   PyOperationIterator &dunderIter() { return *this; }
341 
342   py::object dunderNext() {
343     parentOperation->checkValid();
344     if (mlirOperationIsNull(next)) {
345       throw py::stop_iteration();
346     }
347 
348     PyOperationRef returnOperation =
349         PyOperation::forOperation(parentOperation->getContext(), next);
350     next = mlirOperationGetNextInBlock(next);
351     return returnOperation->createOpView();
352   }
353 
354   static void bind(py::module &m) {
355     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
356         .def("__iter__", &PyOperationIterator::dunderIter)
357         .def("__next__", &PyOperationIterator::dunderNext);
358   }
359 
360 private:
361   PyOperationRef parentOperation;
362   MlirOperation next;
363 };
364 
365 /// Operations are exposed by the C-API as a forward-only linked list. In
366 /// Python, we present them as a more full-featured list-like container but
367 /// optimize it for forward iteration. Iterable operations are always owned
368 /// by a block.
369 class PyOperationList {
370 public:
371   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
372       : parentOperation(std::move(parentOperation)), block(block) {}
373 
374   PyOperationIterator dunderIter() {
375     parentOperation->checkValid();
376     return PyOperationIterator(parentOperation,
377                                mlirBlockGetFirstOperation(block));
378   }
379 
380   intptr_t dunderLen() {
381     parentOperation->checkValid();
382     intptr_t count = 0;
383     MlirOperation childOp = mlirBlockGetFirstOperation(block);
384     while (!mlirOperationIsNull(childOp)) {
385       count += 1;
386       childOp = mlirOperationGetNextInBlock(childOp);
387     }
388     return count;
389   }
390 
391   py::object dunderGetItem(intptr_t index) {
392     parentOperation->checkValid();
393     if (index < 0) {
394       throw SetPyError(PyExc_IndexError,
395                        "attempt to access out of bounds operation");
396     }
397     MlirOperation childOp = mlirBlockGetFirstOperation(block);
398     while (!mlirOperationIsNull(childOp)) {
399       if (index == 0) {
400         return PyOperation::forOperation(parentOperation->getContext(), childOp)
401             ->createOpView();
402       }
403       childOp = mlirOperationGetNextInBlock(childOp);
404       index -= 1;
405     }
406     throw SetPyError(PyExc_IndexError,
407                      "attempt to access out of bounds operation");
408   }
409 
410   static void bind(py::module &m) {
411     py::class_<PyOperationList>(m, "OperationList", py::module_local())
412         .def("__getitem__", &PyOperationList::dunderGetItem)
413         .def("__iter__", &PyOperationList::dunderIter)
414         .def("__len__", &PyOperationList::dunderLen);
415   }
416 
417 private:
418   PyOperationRef parentOperation;
419   MlirBlock block;
420 };
421 
422 } // namespace
423 
424 //------------------------------------------------------------------------------
425 // PyMlirContext
426 //------------------------------------------------------------------------------
427 
428 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
429   py::gil_scoped_acquire acquire;
430   auto &liveContexts = getLiveContexts();
431   liveContexts[context.ptr] = this;
432 }
433 
434 PyMlirContext::~PyMlirContext() {
435   // Note that the only public way to construct an instance is via the
436   // forContext method, which always puts the associated handle into
437   // liveContexts.
438   py::gil_scoped_acquire acquire;
439   getLiveContexts().erase(context.ptr);
440   mlirContextDestroy(context);
441 }
442 
443 py::object PyMlirContext::getCapsule() {
444   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
445 }
446 
447 py::object PyMlirContext::createFromCapsule(py::object capsule) {
448   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
449   if (mlirContextIsNull(rawContext))
450     throw py::error_already_set();
451   return forContext(rawContext).releaseObject();
452 }
453 
454 PyMlirContext *PyMlirContext::createNewContextForInit() {
455   MlirContext context = mlirContextCreate();
456   mlirRegisterAllDialects(context);
457   return new PyMlirContext(context);
458 }
459 
460 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
461   py::gil_scoped_acquire acquire;
462   auto &liveContexts = getLiveContexts();
463   auto it = liveContexts.find(context.ptr);
464   if (it == liveContexts.end()) {
465     // Create.
466     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
467     py::object pyRef = py::cast(unownedContextWrapper);
468     assert(pyRef && "cast to py::object failed");
469     liveContexts[context.ptr] = unownedContextWrapper;
470     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
471   }
472   // Use existing.
473   py::object pyRef = py::cast(it->second);
474   return PyMlirContextRef(it->second, std::move(pyRef));
475 }
476 
477 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
478   static LiveContextMap liveContexts;
479   return liveContexts;
480 }
481 
482 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
483 
484 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
485 
486 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
487 
488 pybind11::object PyMlirContext::contextEnter() {
489   return PyThreadContextEntry::pushContext(*this);
490 }
491 
492 void PyMlirContext::contextExit(pybind11::object excType,
493                                 pybind11::object excVal,
494                                 pybind11::object excTb) {
495   PyThreadContextEntry::popContext(*this);
496 }
497 
498 PyMlirContext &DefaultingPyMlirContext::resolve() {
499   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
500   if (!context) {
501     throw SetPyError(
502         PyExc_RuntimeError,
503         "An MLIR function requires a Context but none was provided in the call "
504         "or from the surrounding environment. Either pass to the function with "
505         "a 'context=' argument or establish a default using 'with Context():'");
506   }
507   return *context;
508 }
509 
510 //------------------------------------------------------------------------------
511 // PyThreadContextEntry management
512 //------------------------------------------------------------------------------
513 
514 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
515   static thread_local std::vector<PyThreadContextEntry> stack;
516   return stack;
517 }
518 
519 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
520   auto &stack = getStack();
521   if (stack.empty())
522     return nullptr;
523   return &stack.back();
524 }
525 
526 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
527                                 py::object insertionPoint,
528                                 py::object location) {
529   auto &stack = getStack();
530   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
531                      std::move(location));
532   // If the new stack has more than one entry and the context of the new top
533   // entry matches the previous, copy the insertionPoint and location from the
534   // previous entry if missing from the new top entry.
535   if (stack.size() > 1) {
536     auto &prev = *(stack.rbegin() + 1);
537     auto &current = stack.back();
538     if (current.context.is(prev.context)) {
539       // Default non-context objects from the previous entry.
540       if (!current.insertionPoint)
541         current.insertionPoint = prev.insertionPoint;
542       if (!current.location)
543         current.location = prev.location;
544     }
545   }
546 }
547 
548 PyMlirContext *PyThreadContextEntry::getContext() {
549   if (!context)
550     return nullptr;
551   return py::cast<PyMlirContext *>(context);
552 }
553 
554 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
555   if (!insertionPoint)
556     return nullptr;
557   return py::cast<PyInsertionPoint *>(insertionPoint);
558 }
559 
560 PyLocation *PyThreadContextEntry::getLocation() {
561   if (!location)
562     return nullptr;
563   return py::cast<PyLocation *>(location);
564 }
565 
566 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
567   auto *tos = getTopOfStack();
568   return tos ? tos->getContext() : nullptr;
569 }
570 
571 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
572   auto *tos = getTopOfStack();
573   return tos ? tos->getInsertionPoint() : nullptr;
574 }
575 
576 PyLocation *PyThreadContextEntry::getDefaultLocation() {
577   auto *tos = getTopOfStack();
578   return tos ? tos->getLocation() : nullptr;
579 }
580 
581 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
582   py::object contextObj = py::cast(context);
583   push(FrameKind::Context, /*context=*/contextObj,
584        /*insertionPoint=*/py::object(),
585        /*location=*/py::object());
586   return contextObj;
587 }
588 
589 void PyThreadContextEntry::popContext(PyMlirContext &context) {
590   auto &stack = getStack();
591   if (stack.empty())
592     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
593   auto &tos = stack.back();
594   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
595     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
596   stack.pop_back();
597 }
598 
599 py::object
600 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
601   py::object contextObj =
602       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
603   py::object insertionPointObj = py::cast(insertionPoint);
604   push(FrameKind::InsertionPoint,
605        /*context=*/contextObj,
606        /*insertionPoint=*/insertionPointObj,
607        /*location=*/py::object());
608   return insertionPointObj;
609 }
610 
611 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
612   auto &stack = getStack();
613   if (stack.empty())
614     throw SetPyError(PyExc_RuntimeError,
615                      "Unbalanced InsertionPoint enter/exit");
616   auto &tos = stack.back();
617   if (tos.frameKind != FrameKind::InsertionPoint &&
618       tos.getInsertionPoint() != &insertionPoint)
619     throw SetPyError(PyExc_RuntimeError,
620                      "Unbalanced InsertionPoint enter/exit");
621   stack.pop_back();
622 }
623 
624 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
625   py::object contextObj = location.getContext().getObject();
626   py::object locationObj = py::cast(location);
627   push(FrameKind::Location, /*context=*/contextObj,
628        /*insertionPoint=*/py::object(),
629        /*location=*/locationObj);
630   return locationObj;
631 }
632 
633 void PyThreadContextEntry::popLocation(PyLocation &location) {
634   auto &stack = getStack();
635   if (stack.empty())
636     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
637   auto &tos = stack.back();
638   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
639     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
640   stack.pop_back();
641 }
642 
643 //------------------------------------------------------------------------------
644 // PyDialect, PyDialectDescriptor, PyDialects
645 //------------------------------------------------------------------------------
646 
647 MlirDialect PyDialects::getDialectForKey(const std::string &key,
648                                          bool attrError) {
649   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
650                                                     {key.data(), key.size()});
651   if (mlirDialectIsNull(dialect)) {
652     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
653                      Twine("Dialect '") + key + "' not found");
654   }
655   return dialect;
656 }
657 
658 //------------------------------------------------------------------------------
659 // PyLocation
660 //------------------------------------------------------------------------------
661 
662 py::object PyLocation::getCapsule() {
663   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
664 }
665 
666 PyLocation PyLocation::createFromCapsule(py::object capsule) {
667   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
668   if (mlirLocationIsNull(rawLoc))
669     throw py::error_already_set();
670   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
671                     rawLoc);
672 }
673 
674 py::object PyLocation::contextEnter() {
675   return PyThreadContextEntry::pushLocation(*this);
676 }
677 
678 void PyLocation::contextExit(py::object excType, py::object excVal,
679                              py::object excTb) {
680   PyThreadContextEntry::popLocation(*this);
681 }
682 
683 PyLocation &DefaultingPyLocation::resolve() {
684   auto *location = PyThreadContextEntry::getDefaultLocation();
685   if (!location) {
686     throw SetPyError(
687         PyExc_RuntimeError,
688         "An MLIR function requires a Location but none was provided in the "
689         "call or from the surrounding environment. Either pass to the function "
690         "with a 'loc=' argument or establish a default using 'with loc:'");
691   }
692   return *location;
693 }
694 
695 //------------------------------------------------------------------------------
696 // PyModule
697 //------------------------------------------------------------------------------
698 
699 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
700     : BaseContextObject(std::move(contextRef)), module(module) {}
701 
702 PyModule::~PyModule() {
703   py::gil_scoped_acquire acquire;
704   auto &liveModules = getContext()->liveModules;
705   assert(liveModules.count(module.ptr) == 1 &&
706          "destroying module not in live map");
707   liveModules.erase(module.ptr);
708   mlirModuleDestroy(module);
709 }
710 
711 PyModuleRef PyModule::forModule(MlirModule module) {
712   MlirContext context = mlirModuleGetContext(module);
713   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
714 
715   py::gil_scoped_acquire acquire;
716   auto &liveModules = contextRef->liveModules;
717   auto it = liveModules.find(module.ptr);
718   if (it == liveModules.end()) {
719     // Create.
720     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
721     // Note that the default return value policy on cast is automatic_reference,
722     // which does not take ownership (delete will not be called).
723     // Just be explicit.
724     py::object pyRef =
725         py::cast(unownedModule, py::return_value_policy::take_ownership);
726     unownedModule->handle = pyRef;
727     liveModules[module.ptr] =
728         std::make_pair(unownedModule->handle, unownedModule);
729     return PyModuleRef(unownedModule, std::move(pyRef));
730   }
731   // Use existing.
732   PyModule *existing = it->second.second;
733   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
734   return PyModuleRef(existing, std::move(pyRef));
735 }
736 
737 py::object PyModule::createFromCapsule(py::object capsule) {
738   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
739   if (mlirModuleIsNull(rawModule))
740     throw py::error_already_set();
741   return forModule(rawModule).releaseObject();
742 }
743 
744 py::object PyModule::getCapsule() {
745   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
746 }
747 
748 //------------------------------------------------------------------------------
749 // PyOperation
750 //------------------------------------------------------------------------------
751 
752 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
753     : BaseContextObject(std::move(contextRef)), operation(operation) {}
754 
755 PyOperation::~PyOperation() {
756   // If the operation has already been invalidated there is nothing to do.
757   if (!valid)
758     return;
759   auto &liveOperations = getContext()->liveOperations;
760   assert(liveOperations.count(operation.ptr) == 1 &&
761          "destroying operation not in live map");
762   liveOperations.erase(operation.ptr);
763   if (!isAttached()) {
764     mlirOperationDestroy(operation);
765   }
766 }
767 
768 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
769                                            MlirOperation operation,
770                                            py::object parentKeepAlive) {
771   auto &liveOperations = contextRef->liveOperations;
772   // Create.
773   PyOperation *unownedOperation =
774       new PyOperation(std::move(contextRef), operation);
775   // Note that the default return value policy on cast is automatic_reference,
776   // which does not take ownership (delete will not be called).
777   // Just be explicit.
778   py::object pyRef =
779       py::cast(unownedOperation, py::return_value_policy::take_ownership);
780   unownedOperation->handle = pyRef;
781   if (parentKeepAlive) {
782     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
783   }
784   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
785   return PyOperationRef(unownedOperation, std::move(pyRef));
786 }
787 
788 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
789                                          MlirOperation operation,
790                                          py::object parentKeepAlive) {
791   auto &liveOperations = contextRef->liveOperations;
792   auto it = liveOperations.find(operation.ptr);
793   if (it == liveOperations.end()) {
794     // Create.
795     return createInstance(std::move(contextRef), operation,
796                           std::move(parentKeepAlive));
797   }
798   // Use existing.
799   PyOperation *existing = it->second.second;
800   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
801   return PyOperationRef(existing, std::move(pyRef));
802 }
803 
804 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
805                                            MlirOperation operation,
806                                            py::object parentKeepAlive) {
807   auto &liveOperations = contextRef->liveOperations;
808   assert(liveOperations.count(operation.ptr) == 0 &&
809          "cannot create detached operation that already exists");
810   (void)liveOperations;
811 
812   PyOperationRef created = createInstance(std::move(contextRef), operation,
813                                           std::move(parentKeepAlive));
814   created->attached = false;
815   return created;
816 }
817 
818 void PyOperation::checkValid() const {
819   if (!valid) {
820     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
821   }
822 }
823 
824 void PyOperationBase::print(py::object fileObject, bool binary,
825                             llvm::Optional<int64_t> largeElementsLimit,
826                             bool enableDebugInfo, bool prettyDebugInfo,
827                             bool printGenericOpForm, bool useLocalScope) {
828   PyOperation &operation = getOperation();
829   operation.checkValid();
830   if (fileObject.is_none())
831     fileObject = py::module::import("sys").attr("stdout");
832 
833   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
834     fileObject.attr("write")("// Verification failed, printing generic form\n");
835     printGenericOpForm = true;
836   }
837 
838   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
839   if (largeElementsLimit)
840     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
841   if (enableDebugInfo)
842     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
843   if (printGenericOpForm)
844     mlirOpPrintingFlagsPrintGenericOpForm(flags);
845 
846   PyFileAccumulator accum(fileObject, binary);
847   py::gil_scoped_release();
848   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
849                               accum.getUserData());
850   mlirOpPrintingFlagsDestroy(flags);
851 }
852 
853 py::object PyOperationBase::getAsm(bool binary,
854                                    llvm::Optional<int64_t> largeElementsLimit,
855                                    bool enableDebugInfo, bool prettyDebugInfo,
856                                    bool printGenericOpForm,
857                                    bool useLocalScope) {
858   py::object fileObject;
859   if (binary) {
860     fileObject = py::module::import("io").attr("BytesIO")();
861   } else {
862     fileObject = py::module::import("io").attr("StringIO")();
863   }
864   print(fileObject, /*binary=*/binary,
865         /*largeElementsLimit=*/largeElementsLimit,
866         /*enableDebugInfo=*/enableDebugInfo,
867         /*prettyDebugInfo=*/prettyDebugInfo,
868         /*printGenericOpForm=*/printGenericOpForm,
869         /*useLocalScope=*/useLocalScope);
870 
871   return fileObject.attr("getvalue")();
872 }
873 
874 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
875   checkValid();
876   if (!isAttached())
877     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
878   MlirOperation operation = mlirOperationGetParentOperation(get());
879   if (mlirOperationIsNull(operation))
880     return {};
881   return PyOperation::forOperation(getContext(), operation);
882 }
883 
884 PyBlock PyOperation::getBlock() {
885   checkValid();
886   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
887   MlirBlock block = mlirOperationGetBlock(get());
888   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
889   assert(parentOperation && "Operation has no parent");
890   return PyBlock{std::move(*parentOperation), block};
891 }
892 
893 py::object PyOperation::getCapsule() {
894   checkValid();
895   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
896 }
897 
898 py::object PyOperation::createFromCapsule(py::object capsule) {
899   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
900   if (mlirOperationIsNull(rawOperation))
901     throw py::error_already_set();
902   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
903   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
904       .releaseObject();
905 }
906 
/// Builds a new operation from unpacked Python arguments and returns its
/// OpView object.
///
/// \param name       Operation name passed to the operation state.
/// \param results    Optional result types; entries must be non-null.
/// \param operands   Optional operand values; entries must be non-null.
/// \param attributes Optional dict mapping string keys to attributes.
/// \param successors Optional successor blocks; entries must be non-null.
/// \param regions    Number of empty regions to create; must be >= 0.
/// \param location   Location for the operation; its context is also used
///                   for the created wrapper.
/// \param maybeIp    Insertion point selector: `False` skips insertion,
///                   `None` uses the thread's default insertion point (if
///                   any), anything else is cast to a PyInsertionPoint.
///
/// All validation happens before the MlirOperationState is populated so that
/// no exception is thrown while the state holds resources.
/// Raises ValueError for a negative region count or None entries, and
/// py::cast_error for attribute keys/values of the wrong type.
py::object PyOperation::create(
    std::string name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, py::object maybeIp) {
  // Staging containers holding the C API handles extracted from the Python
  // arguments.
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      // A null pointer here corresponds to a None passed from Python.
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      // Keys must be convertible to std::string; re-throw with a message
      // that names the operation being created.
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        // The cast is performed before `key` is moved, so `key` is still
        // intact if one of the handlers below runs because the cast failed.
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    // Each name is turned into an identifier in the attribute's own context.
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Create `regions` fresh regions and hand them to the state.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);

  // InsertPoint active?
  // `False` disables insertion entirely; `None` falls back to the default
  // insertion point, which may itself be null (no insertion then).
  if (!maybeIp.is(py::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
      ip = PyThreadContextEntry::getDefaultInsertionPoint();
    } else {
      ip = py::cast<PyInsertionPoint *>(maybeIp);
    }
    if (ip)
      ip->insert(*created.get());
  }

  return created->createOpView();
}
1039 
1040 py::object PyOperation::createOpView() {
1041   checkValid();
1042   MlirIdentifier ident = mlirOperationGetName(get());
1043   MlirStringRef identStr = mlirIdentifierStr(ident);
1044   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1045       StringRef(identStr.data, identStr.length));
1046   if (opViewClass)
1047     return (*opViewClass)(getRef().getObject());
1048   return py::cast(PyOpView(getRef().getObject()));
1049 }
1050 
1051 void PyOperation::erase() {
1052   checkValid();
1053   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1054   // Python reference to a child operation is live. All children should also
1055   // have their `valid` bit set to false.
1056   auto &liveOperations = getContext()->liveOperations;
1057   if (liveOperations.count(operation.ptr))
1058     liveOperations.erase(operation.ptr);
1059   mlirOperationDestroy(operation);
1060   valid = false;
1061 }
1062 
1063 //------------------------------------------------------------------------------
1064 // PyOpView
1065 //------------------------------------------------------------------------------
1066 
1067 py::object
1068 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1069                        py::list operandList,
1070                        llvm::Optional<py::dict> attributes,
1071                        llvm::Optional<std::vector<PyBlock *>> successors,
1072                        llvm::Optional<int> regions,
1073                        DefaultingPyLocation location, py::object maybeIp) {
1074   PyMlirContextRef context = location->getContext();
1075   // Class level operation construction metadata.
1076   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1077   // Operand and result segment specs are either none, which does no
1078   // variadic unpacking, or a list of ints with segment sizes, where each
1079   // element is either a positive number (typically 1 for a scalar) or -1 to
1080   // indicate that it is derived from the length of the same-indexed operand
1081   // or result (implying that it is a list at that position).
1082   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1083   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1084 
1085   std::vector<uint32_t> operandSegmentLengths;
1086   std::vector<uint32_t> resultSegmentLengths;
1087 
1088   // Validate/determine region count.
1089   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1090   int opMinRegionCount = std::get<0>(opRegionSpec);
1091   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1092   if (!regions) {
1093     regions = opMinRegionCount;
1094   }
1095   if (*regions < opMinRegionCount) {
1096     throw py::value_error(
1097         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1098          llvm::Twine(opMinRegionCount) +
1099          " regions but was built with regions=" + llvm::Twine(*regions))
1100             .str());
1101   }
1102   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1103     throw py::value_error(
1104         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1105          llvm::Twine(opMinRegionCount) +
1106          " regions but was built with regions=" + llvm::Twine(*regions))
1107             .str());
1108   }
1109 
1110   // Unpack results.
1111   std::vector<PyType *> resultTypes;
1112   resultTypes.reserve(resultTypeList.size());
1113   if (resultSegmentSpecObj.is_none()) {
1114     // Non-variadic result unpacking.
1115     for (auto it : llvm::enumerate(resultTypeList)) {
1116       try {
1117         resultTypes.push_back(py::cast<PyType *>(it.value()));
1118         if (!resultTypes.back())
1119           throw py::cast_error();
1120       } catch (py::cast_error &err) {
1121         throw py::value_error((llvm::Twine("Result ") +
1122                                llvm::Twine(it.index()) + " of operation \"" +
1123                                name + "\" must be a Type (" + err.what() + ")")
1124                                   .str());
1125       }
1126     }
1127   } else {
1128     // Sized result unpacking.
1129     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1130     if (resultSegmentSpec.size() != resultTypeList.size()) {
1131       throw py::value_error((llvm::Twine("Operation \"") + name +
1132                              "\" requires " +
1133                              llvm::Twine(resultSegmentSpec.size()) +
1134                              "result segments but was provided " +
1135                              llvm::Twine(resultTypeList.size()))
1136                                 .str());
1137     }
1138     resultSegmentLengths.reserve(resultTypeList.size());
1139     for (auto it :
1140          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1141       int segmentSpec = std::get<1>(it.value());
1142       if (segmentSpec == 1 || segmentSpec == 0) {
1143         // Unpack unary element.
1144         try {
1145           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1146           if (resultType) {
1147             resultTypes.push_back(resultType);
1148             resultSegmentLengths.push_back(1);
1149           } else if (segmentSpec == 0) {
1150             // Allowed to be optional.
1151             resultSegmentLengths.push_back(0);
1152           } else {
1153             throw py::cast_error("was None and result is not optional");
1154           }
1155         } catch (py::cast_error &err) {
1156           throw py::value_error((llvm::Twine("Result ") +
1157                                  llvm::Twine(it.index()) + " of operation \"" +
1158                                  name + "\" must be a Type (" + err.what() +
1159                                  ")")
1160                                     .str());
1161         }
1162       } else if (segmentSpec == -1) {
1163         // Unpack sequence by appending.
1164         try {
1165           if (std::get<0>(it.value()).is_none()) {
1166             // Treat it as an empty list.
1167             resultSegmentLengths.push_back(0);
1168           } else {
1169             // Unpack the list.
1170             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1171             for (py::object segmentItem : segment) {
1172               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1173               if (!resultTypes.back()) {
1174                 throw py::cast_error("contained a None item");
1175               }
1176             }
1177             resultSegmentLengths.push_back(segment.size());
1178           }
1179         } catch (std::exception &err) {
1180           // NOTE: Sloppy to be using a catch-all here, but there are at least
1181           // three different unrelated exceptions that can be thrown in the
1182           // above "casts". Just keep the scope above small and catch them all.
1183           throw py::value_error((llvm::Twine("Result ") +
1184                                  llvm::Twine(it.index()) + " of operation \"" +
1185                                  name + "\" must be a Sequence of Types (" +
1186                                  err.what() + ")")
1187                                     .str());
1188         }
1189       } else {
1190         throw py::value_error("Unexpected segment spec");
1191       }
1192     }
1193   }
1194 
1195   // Unpack operands.
1196   std::vector<PyValue *> operands;
1197   operands.reserve(operands.size());
1198   if (operandSegmentSpecObj.is_none()) {
1199     // Non-sized operand unpacking.
1200     for (auto it : llvm::enumerate(operandList)) {
1201       try {
1202         operands.push_back(py::cast<PyValue *>(it.value()));
1203         if (!operands.back())
1204           throw py::cast_error();
1205       } catch (py::cast_error &err) {
1206         throw py::value_error((llvm::Twine("Operand ") +
1207                                llvm::Twine(it.index()) + " of operation \"" +
1208                                name + "\" must be a Value (" + err.what() + ")")
1209                                   .str());
1210       }
1211     }
1212   } else {
1213     // Sized operand unpacking.
1214     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1215     if (operandSegmentSpec.size() != operandList.size()) {
1216       throw py::value_error((llvm::Twine("Operation \"") + name +
1217                              "\" requires " +
1218                              llvm::Twine(operandSegmentSpec.size()) +
1219                              "operand segments but was provided " +
1220                              llvm::Twine(operandList.size()))
1221                                 .str());
1222     }
1223     operandSegmentLengths.reserve(operandList.size());
1224     for (auto it :
1225          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1226       int segmentSpec = std::get<1>(it.value());
1227       if (segmentSpec == 1 || segmentSpec == 0) {
1228         // Unpack unary element.
1229         try {
1230           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1231           if (operandValue) {
1232             operands.push_back(operandValue);
1233             operandSegmentLengths.push_back(1);
1234           } else if (segmentSpec == 0) {
1235             // Allowed to be optional.
1236             operandSegmentLengths.push_back(0);
1237           } else {
1238             throw py::cast_error("was None and operand is not optional");
1239           }
1240         } catch (py::cast_error &err) {
1241           throw py::value_error((llvm::Twine("Operand ") +
1242                                  llvm::Twine(it.index()) + " of operation \"" +
1243                                  name + "\" must be a Value (" + err.what() +
1244                                  ")")
1245                                     .str());
1246         }
1247       } else if (segmentSpec == -1) {
1248         // Unpack sequence by appending.
1249         try {
1250           if (std::get<0>(it.value()).is_none()) {
1251             // Treat it as an empty list.
1252             operandSegmentLengths.push_back(0);
1253           } else {
1254             // Unpack the list.
1255             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1256             for (py::object segmentItem : segment) {
1257               operands.push_back(py::cast<PyValue *>(segmentItem));
1258               if (!operands.back()) {
1259                 throw py::cast_error("contained a None item");
1260               }
1261             }
1262             operandSegmentLengths.push_back(segment.size());
1263           }
1264         } catch (std::exception &err) {
1265           // NOTE: Sloppy to be using a catch-all here, but there are at least
1266           // three different unrelated exceptions that can be thrown in the
1267           // above "casts". Just keep the scope above small and catch them all.
1268           throw py::value_error((llvm::Twine("Operand ") +
1269                                  llvm::Twine(it.index()) + " of operation \"" +
1270                                  name + "\" must be a Sequence of Values (" +
1271                                  err.what() + ")")
1272                                     .str());
1273         }
1274       } else {
1275         throw py::value_error("Unexpected segment spec");
1276       }
1277     }
1278   }
1279 
1280   // Merge operand/result segment lengths into attributes if needed.
1281   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1282     // Dup.
1283     if (attributes) {
1284       attributes = py::dict(*attributes);
1285     } else {
1286       attributes = py::dict();
1287     }
1288     if (attributes->contains("result_segment_sizes") ||
1289         attributes->contains("operand_segment_sizes")) {
1290       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1291                             "'operand_segment_sizes' attribute is unsupported. "
1292                             "Use Operation.create for such low-level access.");
1293     }
1294 
1295     // Add result_segment_sizes attribute.
1296     if (!resultSegmentLengths.empty()) {
1297       int64_t size = resultSegmentLengths.size();
1298       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1299           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1300           resultSegmentLengths.size(), resultSegmentLengths.data());
1301       (*attributes)["result_segment_sizes"] =
1302           PyAttribute(context, segmentLengthAttr);
1303     }
1304 
1305     // Add operand_segment_sizes attribute.
1306     if (!operandSegmentLengths.empty()) {
1307       int64_t size = operandSegmentLengths.size();
1308       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1309           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1310           operandSegmentLengths.size(), operandSegmentLengths.data());
1311       (*attributes)["operand_segment_sizes"] =
1312           PyAttribute(context, segmentLengthAttr);
1313     }
1314   }
1315 
1316   // Delegate to create.
1317   return PyOperation::create(std::move(name),
1318                              /*results=*/std::move(resultTypes),
1319                              /*operands=*/std::move(operands),
1320                              /*attributes=*/std::move(attributes),
1321                              /*successors=*/std::move(successors),
1322                              /*regions=*/*regions, location, maybeIp);
1323 }
1324 
/// Constructs an OpView from any Python object that casts to a
/// PyOperationBase. `operationObject` is initialized from the already
/// initialized `operation` member rather than from the constructor argument.
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1330 
1331 py::object PyOpView::createRawSubclass(py::object userClass) {
1332   // This is... a little gross. The typical pattern is to have a pure python
1333   // class that extends OpView like:
1334   //   class AddFOp(_cext.ir.OpView):
1335   //     def __init__(self, loc, lhs, rhs):
1336   //       operation = loc.context.create_operation(
1337   //           "addf", lhs, rhs, results=[lhs.type])
1338   //       super().__init__(operation)
1339   //
1340   // I.e. The goal of the user facing type is to provide a nice constructor
1341   // that has complete freedom for the op under construction. This is at odds
1342   // with our other desire to sometimes create this object by just passing an
1343   // operation (to initialize the base class). We could do *arg and **kwargs
1344   // munging to try to make it work, but instead, we synthesize a new class
1345   // on the fly which extends this user class (AddFOp in this example) and
1346   // *give it* the base class's __init__ method, thus bypassing the
1347   // intermediate subclass's __init__ method entirely. While slightly,
1348   // underhanded, this is safe/legal because the type hierarchy has not changed
1349   // (we just added a new leaf) and we aren't mucking around with __new__.
1350   // Typically, this new class will be stored on the original as "_Raw" and will
1351   // be used for casts and other things that need a variant of the class that
1352   // is initialized purely from an operation.
1353   py::object parentMetaclass =
1354       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1355   py::dict attributes;
1356   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1357   // now.
1358   //   auto opViewType = py::type::of<PyOpView>();
1359   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1360   attributes["__init__"] = opViewType.attr("__init__");
1361   py::str origName = userClass.attr("__name__");
1362   py::str newName = py::str("_") + origName;
1363   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1364 }
1365 
1366 //------------------------------------------------------------------------------
1367 // PyInsertionPoint.
1368 //------------------------------------------------------------------------------
1369 
/// Creates an insertion point on `block` with no reference operation; insert()
/// treats this as "insert at the end of the block".
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1371 
/// Creates an insertion point that inserts before the given operation. The
/// block is obtained from the reference operation via getBlock(), so
/// `refOperation` must be initialized before `block`.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1375 
1376 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1377   PyOperation &operation = operationBase.getOperation();
1378   if (operation.isAttached())
1379     throw SetPyError(PyExc_ValueError,
1380                      "Attempt to insert operation that is already attached");
1381   block.getParentOperation()->checkValid();
1382   MlirOperation beforeOp = {nullptr};
1383   if (refOperation) {
1384     // Insert before operation.
1385     (*refOperation)->checkValid();
1386     beforeOp = (*refOperation)->get();
1387   } else {
1388     // Insert at end (before null) is only valid if the block does not
1389     // already end in a known terminator (violating this will cause assertion
1390     // failures later).
1391     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1392       throw py::index_error("Cannot insert operation at the end of a block "
1393                             "that already has a terminator. Did you mean to "
1394                             "use 'InsertionPoint.at_block_terminator(block)' "
1395                             "versus 'InsertionPoint(block)'?");
1396     }
1397   }
1398   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1399   operation.setAttached();
1400 }
1401 
1402 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1403   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1404   if (mlirOperationIsNull(firstOp)) {
1405     // Just insert at end.
1406     return PyInsertionPoint(block);
1407   }
1408 
1409   // Insert before first op.
1410   PyOperationRef firstOpRef = PyOperation::forOperation(
1411       block.getParentOperation()->getContext(), firstOp);
1412   return PyInsertionPoint{block, std::move(firstOpRef)};
1413 }
1414 
1415 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1416   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1417   if (mlirOperationIsNull(terminator))
1418     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1419   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1420       block.getParentOperation()->getContext(), terminator);
1421   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1422 }
1423 
/// Context-manager entry: pushes this insertion point onto the thread context
/// stack and returns whatever pushInsertionPoint returns.
py::object PyInsertionPoint::contextEnter() {
  return PyThreadContextEntry::pushInsertionPoint(*this);
}
1427 
/// Context-manager exit: pops this insertion point from the thread context
/// stack. The exception arguments are accepted but not inspected.
void PyInsertionPoint::contextExit(pybind11::object excType,
                                   pybind11::object excVal,
                                   pybind11::object excTb) {
  PyThreadContextEntry::popInsertionPoint(*this);
}
1433 
1434 //------------------------------------------------------------------------------
1435 // PyAttribute.
1436 //------------------------------------------------------------------------------
1437 
/// Compares the wrapped attributes using mlirAttributeEqual.
bool PyAttribute::operator==(const PyAttribute &other) {
  return mlirAttributeEqual(attr, other.attr);
}
1441 
1442 py::object PyAttribute::getCapsule() {
1443   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1444 }
1445 
1446 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1447   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1448   if (mlirAttributeIsNull(rawAttr))
1449     throw py::error_already_set();
1450   return PyAttribute(
1451       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1452 }
1453 
1454 //------------------------------------------------------------------------------
1455 // PyNamedAttribute.
1456 //------------------------------------------------------------------------------
1457 
1458 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1459     : ownedName(new std::string(std::move(ownedName))) {
1460   namedAttr = mlirNamedAttributeGet(
1461       mlirIdentifierGet(mlirAttributeGetContext(attr),
1462                         toMlirStringRef(*this->ownedName)),
1463       attr);
1464 }
1465 
1466 //------------------------------------------------------------------------------
1467 // PyType.
1468 //------------------------------------------------------------------------------
1469 
/// Compares the wrapped types using mlirTypeEqual.
bool PyType::operator==(const PyType &other) {
  return mlirTypeEqual(type, other.type);
}
1473 
1474 py::object PyType::getCapsule() {
1475   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1476 }
1477 
1478 PyType PyType::createFromCapsule(py::object capsule) {
1479   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1480   if (mlirTypeIsNull(rawType))
1481     throw py::error_already_set();
1482   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1483                 rawType);
1484 }
1485 
1486 //------------------------------------------------------------------------------
1487 // PyValue and subclases.
1488 //------------------------------------------------------------------------------
1489 
1490 pybind11::object PyValue::getCapsule() {
1491   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1492 }
1493 
1494 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1495   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1496   if (mlirValueIsNull(value))
1497     throw py::error_already_set();
1498   MlirOperation owner;
1499   if (mlirValueIsAOpResult(value))
1500     owner = mlirOpResultGetOwner(value);
1501   if (mlirValueIsABlockArgument(value))
1502     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1503   if (mlirOperationIsNull(owner))
1504     throw py::error_already_set();
1505   MlirContext ctx = mlirOperationGetContext(owner);
1506   PyOperationRef ownerRef =
1507       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1508   return PyValue(ownerRef, value);
1509 }
1510 
1511 namespace {
1512 /// CRTP base class for Python MLIR values that subclass Value and should be
1513 /// castable from it. The value hierarchy is one level deep and is not supposed
1514 /// to accommodate other levels unless core MLIR changes.
1515 template <typename DerivedTy>
1516 class PyConcreteValue : public PyValue {
1517 public:
1518   // Derived classes must define statics for:
1519   //   IsAFunctionTy isaFunction
1520   //   const char *pyClassName
1521   // and redefine bindDerived.
1522   using ClassTy = py::class_<DerivedTy, PyValue>;
1523   using IsAFunctionTy = bool (*)(MlirValue);
1524 
1525   PyConcreteValue() = default;
1526   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1527       : PyValue(operationRef, value) {}
1528   PyConcreteValue(PyValue &orig)
1529       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1530 
1531   /// Attempts to cast the original value to the derived type and throws on
1532   /// type mismatches.
1533   static MlirValue castFrom(PyValue &orig) {
1534     if (!DerivedTy::isaFunction(orig.get())) {
1535       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1536       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1537                                              DerivedTy::pyClassName +
1538                                              " (from " + origRepr + ")");
1539     }
1540     return orig.get();
1541   }
1542 
1543   /// Binds the Python module objects to functions of this class.
1544   static void bind(py::module &m) {
1545     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1546     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1547     DerivedTy::bindDerived(cls);
1548   }
1549 
1550   /// Implemented by derived classes to add methods to the Python subclass.
1551   static void bindDerived(ClassTy &m) {}
1552 };
1553 
1554 /// Python wrapper for MlirBlockArgument.
1555 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1556 public:
1557   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1558   static constexpr const char *pyClassName = "BlockArgument";
1559   using PyConcreteValue::PyConcreteValue;
1560 
1561   static void bindDerived(ClassTy &c) {
1562     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1563       return PyBlock(self.getParentOperation(),
1564                      mlirBlockArgumentGetOwner(self.get()));
1565     });
1566     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1567       return mlirBlockArgumentGetArgNumber(self.get());
1568     });
1569     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1570       return mlirBlockArgumentSetType(self.get(), type);
1571     });
1572   }
1573 };
1574 
1575 /// Python wrapper for MlirOpResult.
1576 class PyOpResult : public PyConcreteValue<PyOpResult> {
1577 public:
1578   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1579   static constexpr const char *pyClassName = "OpResult";
1580   using PyConcreteValue::PyConcreteValue;
1581 
1582   static void bindDerived(ClassTy &c) {
1583     c.def_property_readonly("owner", [](PyOpResult &self) {
1584       assert(
1585           mlirOperationEqual(self.getParentOperation()->get(),
1586                              mlirOpResultGetOwner(self.get())) &&
1587           "expected the owner of the value in Python to match that in the IR");
1588       return self.getParentOperation().getObject();
1589     });
1590     c.def_property_readonly("result_number", [](PyOpResult &self) {
1591       return mlirOpResultGetResultNumber(self.get());
1592     });
1593   }
1594 };
1595 
1596 /// Returns the list of types of the values held by container.
1597 template <typename Container>
1598 static std::vector<PyType> getValueTypes(Container &container,
1599                                          PyMlirContextRef &context) {
1600   std::vector<PyType> result;
1601   result.reserve(container.getNumElements());
1602   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1603     result.push_back(
1604         PyType(context, mlirValueGetType(container.getElement(i).get())));
1605   }
1606   return result;
1607 }
1608 
1609 /// A list of block arguments. Internally, these are stored as consecutive
1610 /// elements, random access is cheap. The argument list is associated with the
1611 /// operation that contains the block (detached blocks are not allowed in
1612 /// Python bindings) and extends its lifetime.
1613 class PyBlockArgumentList
1614     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1615 public:
1616   static constexpr const char *pyClassName = "BlockArgumentList";
1617 
1618   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1619                       intptr_t startIndex = 0, intptr_t length = -1,
1620                       intptr_t step = 1)
1621       : Sliceable(startIndex,
1622                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1623                   step),
1624         operation(std::move(operation)), block(block) {}
1625 
1626   /// Returns the number of arguments in the list.
1627   intptr_t getNumElements() {
1628     operation->checkValid();
1629     return mlirBlockGetNumArguments(block);
1630   }
1631 
1632   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1633   PyBlockArgument getElement(intptr_t pos) {
1634     MlirValue argument = mlirBlockGetArgument(block, pos);
1635     return PyBlockArgument(operation, argument);
1636   }
1637 
1638   /// Returns a sublist of this list.
1639   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1640                             intptr_t step) {
1641     return PyBlockArgumentList(operation, block, startIndex, length, step);
1642   }
1643 
1644   static void bindDerived(ClassTy &c) {
1645     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1646       return getValueTypes(self, self.operation->getContext());
1647     });
1648   }
1649 
1650 private:
1651   PyOperationRef operation;
1652   MlirBlock block;
1653 };
1654 
1655 /// A list of operation operands. Internally, these are stored as consecutive
1656 /// elements, random access is cheap. The result list is associated with the
1657 /// operation whose results these are, and extends the lifetime of this
1658 /// operation.
1659 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1660 public:
1661   static constexpr const char *pyClassName = "OpOperandList";
1662 
1663   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1664                   intptr_t length = -1, intptr_t step = 1)
1665       : Sliceable(startIndex,
1666                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1667                                : length,
1668                   step),
1669         operation(operation) {}
1670 
1671   intptr_t getNumElements() {
1672     operation->checkValid();
1673     return mlirOperationGetNumOperands(operation->get());
1674   }
1675 
1676   PyValue getElement(intptr_t pos) {
1677     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1678     MlirOperation owner;
1679     if (mlirValueIsAOpResult(operand))
1680       owner = mlirOpResultGetOwner(operand);
1681     else if (mlirValueIsABlockArgument(operand))
1682       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1683     else
1684       assert(false && "Value must be an block arg or op result.");
1685     PyOperationRef pyOwner =
1686         PyOperation::forOperation(operation->getContext(), owner);
1687     return PyValue(pyOwner, operand);
1688   }
1689 
1690   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1691     return PyOpOperandList(operation, startIndex, length, step);
1692   }
1693 
1694   void dunderSetItem(intptr_t index, PyValue value) {
1695     index = wrapIndex(index);
1696     mlirOperationSetOperand(operation->get(), index, value.get());
1697   }
1698 
1699   static void bindDerived(ClassTy &c) {
1700     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1701   }
1702 
1703 private:
1704   PyOperationRef operation;
1705 };
1706 
1707 /// A list of operation results. Internally, these are stored as consecutive
1708 /// elements, random access is cheap. The result list is associated with the
1709 /// operation whose results these are, and extends the lifetime of this
1710 /// operation.
1711 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1712 public:
1713   static constexpr const char *pyClassName = "OpResultList";
1714 
1715   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1716                  intptr_t length = -1, intptr_t step = 1)
1717       : Sliceable(startIndex,
1718                   length == -1 ? mlirOperationGetNumResults(operation->get())
1719                                : length,
1720                   step),
1721         operation(operation) {}
1722 
1723   intptr_t getNumElements() {
1724     operation->checkValid();
1725     return mlirOperationGetNumResults(operation->get());
1726   }
1727 
1728   PyOpResult getElement(intptr_t index) {
1729     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1730     return PyOpResult(value);
1731   }
1732 
1733   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1734     return PyOpResultList(operation, startIndex, length, step);
1735   }
1736 
1737   static void bindDerived(ClassTy &c) {
1738     c.def_property_readonly("types", [](PyOpResultList &self) {
1739       return getValueTypes(self, self.operation->getContext());
1740     });
1741   }
1742 
1743 private:
1744   PyOperationRef operation;
1745 };
1746 
1747 /// A list of operation attributes. Can be indexed by name, producing
1748 /// attributes, or by index, producing named attributes.
1749 class PyOpAttributeMap {
1750 public:
1751   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1752 
1753   PyAttribute dunderGetItemNamed(const std::string &name) {
1754     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1755                                                          toMlirStringRef(name));
1756     if (mlirAttributeIsNull(attr)) {
1757       throw SetPyError(PyExc_KeyError,
1758                        "attempt to access a non-existent attribute");
1759     }
1760     return PyAttribute(operation->getContext(), attr);
1761   }
1762 
1763   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1764     if (index < 0 || index >= dunderLen()) {
1765       throw SetPyError(PyExc_IndexError,
1766                        "attempt to access out of bounds attribute");
1767     }
1768     MlirNamedAttribute namedAttr =
1769         mlirOperationGetAttribute(operation->get(), index);
1770     return PyNamedAttribute(
1771         namedAttr.attribute,
1772         std::string(mlirIdentifierStr(namedAttr.name).data));
1773   }
1774 
1775   void dunderSetItem(const std::string &name, PyAttribute attr) {
1776     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1777                                     attr);
1778   }
1779 
1780   void dunderDelItem(const std::string &name) {
1781     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1782                                                      toMlirStringRef(name));
1783     if (!removed)
1784       throw SetPyError(PyExc_KeyError,
1785                        "attempt to delete a non-existent attribute");
1786   }
1787 
1788   intptr_t dunderLen() {
1789     return mlirOperationGetNumAttributes(operation->get());
1790   }
1791 
1792   bool dunderContains(const std::string &name) {
1793     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1794         operation->get(), toMlirStringRef(name)));
1795   }
1796 
1797   static void bind(py::module &m) {
1798     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
1799         .def("__contains__", &PyOpAttributeMap::dunderContains)
1800         .def("__len__", &PyOpAttributeMap::dunderLen)
1801         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1802         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1803         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1804         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1805   }
1806 
1807 private:
1808   PyOperationRef operation;
1809 };
1810 
1811 } // end namespace
1812 
1813 //------------------------------------------------------------------------------
1814 // Populates the core exports of the 'ir' submodule.
1815 //------------------------------------------------------------------------------
1816 
1817 void mlir::python::populateIRCore(py::module &m) {
1818   //----------------------------------------------------------------------------
1819   // Mapping of MlirContext.
1820   //----------------------------------------------------------------------------
1821   py::class_<PyMlirContext>(m, "Context", py::module_local())
1822       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1823       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1824       .def("_get_context_again",
1825            [](PyMlirContext &self) {
1826              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1827              return ref.releaseObject();
1828            })
1829       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1830       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1831       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1832                              &PyMlirContext::getCapsule)
1833       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1834       .def("__enter__", &PyMlirContext::contextEnter)
1835       .def("__exit__", &PyMlirContext::contextExit)
1836       .def_property_readonly_static(
1837           "current",
1838           [](py::object & /*class*/) {
1839             auto *context = PyThreadContextEntry::getDefaultContext();
1840             if (!context)
1841               throw SetPyError(PyExc_ValueError, "No current Context");
1842             return context;
1843           },
1844           "Gets the Context bound to the current thread or raises ValueError")
1845       .def_property_readonly(
1846           "dialects",
1847           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1848           "Gets a container for accessing dialects by name")
1849       .def_property_readonly(
1850           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1851           "Alias for 'dialect'")
1852       .def(
1853           "get_dialect_descriptor",
1854           [=](PyMlirContext &self, std::string &name) {
1855             MlirDialect dialect = mlirContextGetOrLoadDialect(
1856                 self.get(), {name.data(), name.size()});
1857             if (mlirDialectIsNull(dialect)) {
1858               throw SetPyError(PyExc_ValueError,
1859                                Twine("Dialect '") + name + "' not found");
1860             }
1861             return PyDialectDescriptor(self.getRef(), dialect);
1862           },
1863           "Gets or loads a dialect by name, returning its descriptor object")
1864       .def_property(
1865           "allow_unregistered_dialects",
1866           [](PyMlirContext &self) -> bool {
1867             return mlirContextGetAllowUnregisteredDialects(self.get());
1868           },
1869           [](PyMlirContext &self, bool value) {
1870             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1871           })
1872       .def("enable_multithreading",
1873            [](PyMlirContext &self, bool enable) {
1874              mlirContextEnableMultithreading(self.get(), enable);
1875            })
1876       .def("is_registered_operation",
1877            [](PyMlirContext &self, std::string &name) {
1878              return mlirContextIsRegisteredOperation(
1879                  self.get(), MlirStringRef{name.data(), name.size()});
1880            });
1881 
1882   //----------------------------------------------------------------------------
1883   // Mapping of PyDialectDescriptor
1884   //----------------------------------------------------------------------------
1885   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
1886       .def_property_readonly("namespace",
1887                              [](PyDialectDescriptor &self) {
1888                                MlirStringRef ns =
1889                                    mlirDialectGetNamespace(self.get());
1890                                return py::str(ns.data, ns.length);
1891                              })
1892       .def("__repr__", [](PyDialectDescriptor &self) {
1893         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1894         std::string repr("<DialectDescriptor ");
1895         repr.append(ns.data, ns.length);
1896         repr.append(">");
1897         return repr;
1898       });
1899 
1900   //----------------------------------------------------------------------------
1901   // Mapping of PyDialects
1902   //----------------------------------------------------------------------------
1903   py::class_<PyDialects>(m, "Dialects", py::module_local())
1904       .def("__getitem__",
1905            [=](PyDialects &self, std::string keyName) {
1906              MlirDialect dialect =
1907                  self.getDialectForKey(keyName, /*attrError=*/false);
1908              py::object descriptor =
1909                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1910              return createCustomDialectWrapper(keyName, std::move(descriptor));
1911            })
1912       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1913         MlirDialect dialect =
1914             self.getDialectForKey(attrName, /*attrError=*/true);
1915         py::object descriptor =
1916             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1917         return createCustomDialectWrapper(attrName, std::move(descriptor));
1918       });
1919 
1920   //----------------------------------------------------------------------------
1921   // Mapping of PyDialect
1922   //----------------------------------------------------------------------------
1923   py::class_<PyDialect>(m, "Dialect", py::module_local())
1924       .def(py::init<py::object>(), "descriptor")
1925       .def_property_readonly(
1926           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1927       .def("__repr__", [](py::object self) {
1928         auto clazz = self.attr("__class__");
1929         return py::str("<Dialect ") +
1930                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1931                clazz.attr("__module__") + py::str(".") +
1932                clazz.attr("__name__") + py::str(")>");
1933       });
1934 
1935   //----------------------------------------------------------------------------
1936   // Mapping of Location
1937   //----------------------------------------------------------------------------
  // Python class `Location`, wrapping PyLocation. Provides capsule interop,
  // context-manager support, equality, the thread-bound `current` location and
  // static factories for unknown, file and name locations.
  py::class_<PyLocation>(m, "Location", py::module_local())
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
      .def("__enter__", &PyLocation::contextEnter)
      .def("__exit__", &PyLocation::contextExit)
      // Equality between two locations is delegated to the C API.
      .def("__eq__",
           [](PyLocation &self, PyLocation &other) -> bool {
             return mlirLocationEqual(self, other);
           })
      // Fallback overload, registered after the typed one: comparing against
      // any other object yields false.
      .def("__eq__", [](PyLocation &self, py::object other) { return false; })
      .def_property_readonly_static(
          "current",
          [](py::object & /*class*/) {
            auto *loc = PyThreadContextEntry::getDefaultLocation();
            if (!loc)
              throw SetPyError(PyExc_ValueError, "No current Location");
            return loc;
          },
          "Gets the Location bound to the current thread or raises ValueError")
      .def_static(
          "unknown",
          [](DefaultingPyMlirContext context) {
            return PyLocation(context->getRef(),
                              mlirLocationUnknownGet(context->get()));
          },
          py::arg("context") = py::none(),
          "Gets a Location representing an unknown location")
      .def_static(
          "file",
          [](std::string filename, int line, int col,
             DefaultingPyMlirContext context) {
            return PyLocation(
                context->getRef(),
                mlirLocationFileLineColGet(
                    context->get(), toMlirStringRef(filename), line, col));
          },
          py::arg("filename"), py::arg("line"), py::arg("col"),
          py::arg("context") = py::none(), kContextGetFileLocationDocstring)
      // When no child location is given, an unknown location of the same
      // context is used as the child.
      .def_static(
          "name",
          [](std::string name, llvm::Optional<PyLocation> childLoc,
             DefaultingPyMlirContext context) {
            return PyLocation(
                context->getRef(),
                mlirLocationNameGet(
                    context->get(), toMlirStringRef(name),
                    childLoc ? childLoc->get()
                             : mlirLocationUnknownGet(context->get())));
          },
          py::arg("name"), py::arg("childLoc") = py::none(),
          py::arg("context") = py::none(), kContextGetNameLocationDocString)
      .def_property_readonly(
          "context",
          [](PyLocation &self) { return self.getContext().getObject(); },
          "Context that owns the Location")
      // The repr is the text printed by mlirLocationPrint, collected through a
      // PyPrintAccumulator.
      .def("__repr__", [](PyLocation &self) {
        PyPrintAccumulator printAccum;
        mlirLocationPrint(self, printAccum.getCallback(),
                          printAccum.getUserData());
        return printAccum.join();
      });
1999 
2000   //----------------------------------------------------------------------------
2001   // Mapping of Module
2002   //----------------------------------------------------------------------------
  // Python class `Module`, wrapping PyModule. Provides capsule interop, the
  // `parse` and `create` factories, access to the module as an operation and
  // to its body block, and printing helpers.
  py::class_<PyModule>(m, "Module", py::module_local())
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
      .def_static(
          "parse",
          [](const std::string moduleAsm, DefaultingPyMlirContext context) {
            MlirModule module = mlirModuleCreateParse(
                context->get(), toMlirStringRef(moduleAsm));
            // TODO: Rework error reporting once diagnostic engine is exposed
            // in C API.
            if (mlirModuleIsNull(module)) {
              throw SetPyError(
                  PyExc_ValueError,
                  "Unable to parse module assembly (see diagnostics)");
            }
            return PyModule::forModule(module).releaseObject();
          },
          py::arg("asm"), py::arg("context") = py::none(),
          kModuleParseDocstring)
      .def_static(
          "create",
          [](DefaultingPyLocation loc) {
            MlirModule module = mlirModuleCreateEmpty(loc);
            return PyModule::forModule(module).releaseObject();
          },
          py::arg("loc") = py::none(), "Creates an empty module")
      .def_property_readonly(
          "context",
          [](PyModule &self) { return self.getContext().getObject(); },
          "Context that created the Module")
      // The module's Python object is passed as the third argument of
      // forOperation (presumably to keep the module alive while the operation
      // wrapper exists -- confirm against PyOperation::forOperation).
      .def_property_readonly(
          "operation",
          [](PyModule &self) {
            return PyOperation::forOperation(self.getContext(),
                                             mlirModuleGetOperation(self.get()),
                                             self.getRef().releaseObject())
                .releaseObject();
          },
          "Accesses the module as an operation")
      // The returned block is associated with the operation wrapper of the
      // module, obtained the same way as for the `operation` property.
      .def_property_readonly(
          "body",
          [](PyModule &self) {
            PyOperationRef module_op = PyOperation::forOperation(
                self.getContext(), mlirModuleGetOperation(self.get()),
                self.getRef().releaseObject());
            PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
            return returnBlock;
          },
          "Return the block for this module")
      .def(
          "dump",
          [](PyModule &self) {
            mlirOperationDump(mlirModuleGetOperation(self.get()));
          },
          kDumpDocstring)
      // The string form is the text printed by mlirOperationPrint for the
      // module operation, collected through a PyPrintAccumulator.
      .def(
          "__str__",
          [](PyModule &self) {
            MlirOperation operation = mlirModuleGetOperation(self.get());
            PyPrintAccumulator printAccum;
            mlirOperationPrint(operation, printAccum.getCallback(),
                               printAccum.getUserData());
            return printAccum.join();
          },
          kOperationStrDunderDocstring);
2068 
2069   //----------------------------------------------------------------------------
2070   // Mapping of Operation.
2071   //----------------------------------------------------------------------------
  // Python class `_OperationBase`: members shared by the `Operation` and
  // `OpView` classes registered below, all implemented on top of
  // PyOperationBase::getOperation().
  py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                             [](PyOperationBase &self) {
                               return self.getOperation().getCapsule();
                             })
      // Two wrappers are equal when getOperation() yields the same PyOperation
      // object for both (address comparison).
      .def("__eq__",
           [](PyOperationBase &self, PyOperationBase &other) {
             return &self.getOperation() == &other.getOperation();
           })
      // Fallback overload, registered after the typed one: comparing against
      // any other object yields false.
      .def("__eq__",
           [](PyOperationBase &self, py::object other) { return false; })
      .def_property_readonly("attributes",
                             [](PyOperationBase &self) {
                               return PyOpAttributeMap(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("operands",
                             [](PyOperationBase &self) {
                               return PyOpOperandList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("regions",
                             [](PyOperationBase &self) {
                               return PyRegionList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly(
          "results",
          [](PyOperationBase &self) {
            return PyOpResultList(self.getOperation().getRef());
          },
          "Returns the list of Operation results.")
      // Raises ValueError unless the operation has exactly one result.
      .def_property_readonly(
          "result",
          [](PyOperationBase &self) {
            auto &operation = self.getOperation();
            auto numResults = mlirOperationGetNumResults(operation);
            if (numResults != 1) {
              auto name = mlirIdentifierStr(mlirOperationGetName(operation));
              throw SetPyError(
                  PyExc_ValueError,
                  Twine("Cannot call .result on operation ") +
                      StringRef(name.data, name.length) + " which has " +
                      Twine(numResults) +
                      " results (it is only valid for operations with a "
                      "single result)");
            }
            return PyOpResult(operation.getRef(),
                              mlirOperationGetResult(operation, 0));
          },
          "Shortcut to get an op result if it has only one (throws an error "
          "otherwise).")
      // Iterating an operation goes through a PyRegionIterator.
      .def("__iter__",
           [](PyOperationBase &self) {
             return PyRegionIterator(self.getOperation().getRef());
           })
      // Same as get_asm with every option left at the default value used in
      // the get_asm binding below.
      .def(
          "__str__",
          [](PyOperationBase &self) {
            return self.getAsm(/*binary=*/false,
                               /*largeElementsLimit=*/llvm::None,
                               /*enableDebugInfo=*/false,
                               /*prettyDebugInfo=*/false,
                               /*printGenericOpForm=*/false,
                               /*useLocalScope=*/false);
          },
          "Returns the assembly form of the operation.")
      .def("print", &PyOperationBase::print,
           // Careful: Lots of arguments must match up with print method.
           py::arg("file") = py::none(), py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false, kOperationPrintDocstring)
      .def("get_asm", &PyOperationBase::getAsm,
           // Careful: Lots of arguments must match up with get_asm method.
           py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
      .def(
          "verify",
          [](PyOperationBase &self) {
            return mlirOperationVerify(self.getOperation());
          },
          "Verify the operation and return true if it passes, false if it "
          "fails.");
2162 
  // Python class `Operation` (derived from `_OperationBase`), wrapping
  // PyOperation. Adds creation, erasure, capsule interop and the `parent`,
  // `name`, `context` and `opview` properties.
  py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
      .def_static("create", &PyOperation::create, py::arg("name"),
                  py::arg("results") = py::none(),
                  py::arg("operands") = py::none(),
                  py::arg("attributes") = py::none(),
                  py::arg("successors") = py::none(), py::arg("regions") = 0,
                  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
                  kOperationCreateDocstring)
      // Yields None when getParentOperation() returns an empty result.
      .def_property_readonly("parent",
                             [](PyOperation &self) -> py::object {
                               auto parent = self.getParentOperation();
                               if (parent)
                                 return parent->getObject();
                               return py::none();
                             })
      .def("erase", &PyOperation::erase)
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                             &PyOperation::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
      // Name of the operation as a Python str; the operation is checked for
      // validity first.
      .def_property_readonly("name",
                             [](PyOperation &self) {
                               self.checkValid();
                               MlirOperation operation = self.get();
                               MlirStringRef name = mlirIdentifierStr(
                                   mlirOperationGetName(operation));
                               return py::str(name.data, name.length);
                             })
      .def_property_readonly(
          "context",
          [](PyOperation &self) {
            self.checkValid();
            return self.getContext().getObject();
          },
          "Context that owns the Operation")
      .def_property_readonly("opview", &PyOperation::createOpView);
2198 
2199   auto opViewClass =
2200       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2201           .def(py::init<py::object>())
2202           .def_property_readonly("operation", &PyOpView::getOperationObject)
2203           .def_property_readonly(
2204               "context",
2205               [](PyOpView &self) {
2206                 return self.getOperation().getContext().getObject();
2207               },
2208               "Context that owns the Operation")
2209           .def("__str__", [](PyOpView &self) {
2210             return py::str(self.getOperationObject());
2211           });
  // Class-level defaults installed on the base OpView class.
  // NOTE(review): the tuple presumably encodes (minimum region count, flag
  // describing variadic regions) -- confirm against PyOpView::buildGeneric.
  opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
  // Operand/result segment specs default to None on the base class.
  opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
  opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
  // Expose PyOpView::buildGeneric as `build_generic` through the `classmethod`
  // helper; every argument other than `cls` defaults to None.
  opViewClass.attr("build_generic") = classmethod(
      &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
      py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
      py::arg("successors") = py::none(), py::arg("regions") = py::none(),
      py::arg("loc") = py::none(), py::arg("ip") = py::none(),
      "Builds a specific, generated OpView based on class level attributes.");
2221 
2222   //----------------------------------------------------------------------------
2223   // Mapping of PyRegion.
2224   //----------------------------------------------------------------------------
2225   py::class_<PyRegion>(m, "Region", py::module_local())
2226       .def_property_readonly(
2227           "blocks",
2228           [](PyRegion &self) {
2229             return PyBlockList(self.getParentOperation(), self.get());
2230           },
2231           "Returns a forward-optimized sequence of blocks.")
2232       .def(
2233           "__iter__",
2234           [](PyRegion &self) {
2235             self.checkValid();
2236             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2237             return PyBlockIterator(self.getParentOperation(), firstBlock);
2238           },
2239           "Iterates over blocks in the region.")
2240       .def("__eq__",
2241            [](PyRegion &self, PyRegion &other) {
2242              return self.get().ptr == other.get().ptr;
2243            })
2244       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2245 
2246   //----------------------------------------------------------------------------
2247   // Mapping of PyBlock.
2248   //----------------------------------------------------------------------------
2249   py::class_<PyBlock>(m, "Block", py::module_local())
2250       .def_property_readonly(
2251           "owner",
2252           [](PyBlock &self) {
2253             return self.getParentOperation()->createOpView();
2254           },
2255           "Returns the owning operation of this block.")
2256       .def_property_readonly(
2257           "region",
2258           [](PyBlock &self) {
2259             MlirRegion region = mlirBlockGetParentRegion(self.get());
2260             return PyRegion(self.getParentOperation(), region);
2261           },
2262           "Returns the owning region of this block.")
2263       .def_property_readonly(
2264           "arguments",
2265           [](PyBlock &self) {
2266             return PyBlockArgumentList(self.getParentOperation(), self.get());
2267           },
2268           "Returns a list of block arguments.")
2269       .def_property_readonly(
2270           "operations",
2271           [](PyBlock &self) {
2272             return PyOperationList(self.getParentOperation(), self.get());
2273           },
2274           "Returns a forward-optimized sequence of operations.")
2275       .def(
2276           "create_before",
2277           [](PyBlock &self, py::args pyArgTypes) {
2278             self.checkValid();
2279             llvm::SmallVector<MlirType, 4> argTypes;
2280             argTypes.reserve(pyArgTypes.size());
2281             for (auto &pyArg : pyArgTypes) {
2282               argTypes.push_back(pyArg.cast<PyType &>());
2283             }
2284 
2285             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2286             MlirRegion region = mlirBlockGetParentRegion(self.get());
2287             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2288             return PyBlock(self.getParentOperation(), block);
2289           },
2290           "Creates and returns a new Block before this block "
2291           "(with given argument types).")
2292       .def(
2293           "create_after",
2294           [](PyBlock &self, py::args pyArgTypes) {
2295             self.checkValid();
2296             llvm::SmallVector<MlirType, 4> argTypes;
2297             argTypes.reserve(pyArgTypes.size());
2298             for (auto &pyArg : pyArgTypes) {
2299               argTypes.push_back(pyArg.cast<PyType &>());
2300             }
2301 
2302             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2303             MlirRegion region = mlirBlockGetParentRegion(self.get());
2304             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2305             return PyBlock(self.getParentOperation(), block);
2306           },
2307           "Creates and returns a new Block after this block "
2308           "(with given argument types).")
2309       .def(
2310           "__iter__",
2311           [](PyBlock &self) {
2312             self.checkValid();
2313             MlirOperation firstOperation =
2314                 mlirBlockGetFirstOperation(self.get());
2315             return PyOperationIterator(self.getParentOperation(),
2316                                        firstOperation);
2317           },
2318           "Iterates over operations in the block.")
2319       .def("__eq__",
2320            [](PyBlock &self, PyBlock &other) {
2321              return self.get().ptr == other.get().ptr;
2322            })
2323       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2324       .def(
2325           "__str__",
2326           [](PyBlock &self) {
2327             self.checkValid();
2328             PyPrintAccumulator printAccum;
2329             mlirBlockPrint(self.get(), printAccum.getCallback(),
2330                            printAccum.getUserData());
2331             return printAccum.join();
2332           },
2333           "Returns the assembly form of the block.");
2334 
2335   //----------------------------------------------------------------------------
2336   // Mapping of PyInsertionPoint.
2337   //----------------------------------------------------------------------------
2338 
2339   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2340       .def(py::init<PyBlock &>(), py::arg("block"),
2341            "Inserts after the last operation but still inside the block.")
2342       .def("__enter__", &PyInsertionPoint::contextEnter)
2343       .def("__exit__", &PyInsertionPoint::contextExit)
2344       .def_property_readonly_static(
2345           "current",
2346           [](py::object & /*class*/) {
2347             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2348             if (!ip)
2349               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2350             return ip;
2351           },
2352           "Gets the InsertionPoint bound to the current thread or raises "
2353           "ValueError if none has been set")
2354       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2355            "Inserts before a referenced operation.")
2356       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2357                   py::arg("block"), "Inserts at the beginning of the block.")
2358       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2359                   py::arg("block"), "Inserts before the block terminator.")
2360       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2361            "Inserts an operation.")
2362       .def_property_readonly(
2363           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2364           "Returns the block that this InsertionPoint points to.");
2365 
2366   //----------------------------------------------------------------------------
2367   // Mapping of PyAttribute.
2368   //----------------------------------------------------------------------------
2369   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2370       // Delegate to the PyAttribute copy constructor, which will also lifetime
2371       // extend the backing context which owns the MlirAttribute.
2372       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2373            "Casts the passed attribute to the generic Attribute")
2374       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2375                              &PyAttribute::getCapsule)
2376       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2377       .def_static(
2378           "parse",
2379           [](std::string attrSpec, DefaultingPyMlirContext context) {
2380             MlirAttribute type = mlirAttributeParseGet(
2381                 context->get(), toMlirStringRef(attrSpec));
2382             // TODO: Rework error reporting once diagnostic engine is exposed
2383             // in C API.
2384             if (mlirAttributeIsNull(type)) {
2385               throw SetPyError(PyExc_ValueError,
2386                                Twine("Unable to parse attribute: '") +
2387                                    attrSpec + "'");
2388             }
2389             return PyAttribute(context->getRef(), type);
2390           },
2391           py::arg("asm"), py::arg("context") = py::none(),
2392           "Parses an attribute from an assembly form")
2393       .def_property_readonly(
2394           "context",
2395           [](PyAttribute &self) { return self.getContext().getObject(); },
2396           "Context that owns the Attribute")
2397       .def_property_readonly("type",
2398                              [](PyAttribute &self) {
2399                                return PyType(self.getContext()->getRef(),
2400                                              mlirAttributeGetType(self));
2401                              })
2402       .def(
2403           "get_named",
2404           [](PyAttribute &self, std::string name) {
2405             return PyNamedAttribute(self, std::move(name));
2406           },
2407           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2408       .def("__eq__",
2409            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2410       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2411       .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; })
2412       .def(
2413           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2414           kDumpDocstring)
2415       .def(
2416           "__str__",
2417           [](PyAttribute &self) {
2418             PyPrintAccumulator printAccum;
2419             mlirAttributePrint(self, printAccum.getCallback(),
2420                                printAccum.getUserData());
2421             return printAccum.join();
2422           },
2423           "Returns the assembly form of the Attribute.")
2424       .def("__repr__", [](PyAttribute &self) {
2425         // Generally, assembly formats are not printed for __repr__ because
2426         // this can cause exceptionally long debug output and exceptions.
2427         // However, attribute values are generally considered useful and are
2428         // printed. This may need to be re-evaluated if debug dumps end up
2429         // being excessive.
2430         PyPrintAccumulator printAccum;
2431         printAccum.parts.append("Attribute(");
2432         mlirAttributePrint(self, printAccum.getCallback(),
2433                            printAccum.getUserData());
2434         printAccum.parts.append(")");
2435         return printAccum.join();
2436       });
2437 
2438   //----------------------------------------------------------------------------
  // Mapping of PyNamedAttribute.
2440   //----------------------------------------------------------------------------
2441   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2442       .def("__repr__",
2443            [](PyNamedAttribute &self) {
2444              PyPrintAccumulator printAccum;
2445              printAccum.parts.append("NamedAttribute(");
2446              printAccum.parts.append(
2447                  mlirIdentifierStr(self.namedAttr.name).data);
2448              printAccum.parts.append("=");
2449              mlirAttributePrint(self.namedAttr.attribute,
2450                                 printAccum.getCallback(),
2451                                 printAccum.getUserData());
2452              printAccum.parts.append(")");
2453              return printAccum.join();
2454            })
2455       .def_property_readonly(
2456           "name",
2457           [](PyNamedAttribute &self) {
2458             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2459                            mlirIdentifierStr(self.namedAttr.name).length);
2460           },
2461           "The name of the NamedAttribute binding")
2462       .def_property_readonly(
2463           "attr",
2464           [](PyNamedAttribute &self) {
2465             // TODO: When named attribute is removed/refactored, also remove
2466             // this constructor (it does an inefficient table lookup).
2467             auto contextRef = PyMlirContext::forContext(
2468                 mlirAttributeGetContext(self.namedAttr.attribute));
2469             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2470           },
2471           py::keep_alive<0, 1>(),
2472           "The underlying generic attribute of the NamedAttribute binding");
2473 
2474   //----------------------------------------------------------------------------
2475   // Mapping of PyType.
2476   //----------------------------------------------------------------------------
2477   py::class_<PyType>(m, "Type", py::module_local())
2478       // Delegate to the PyType copy constructor, which will also lifetime
2479       // extend the backing context which owns the MlirType.
2480       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2481            "Casts the passed type to the generic Type")
2482       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2483       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2484       .def_static(
2485           "parse",
2486           [](std::string typeSpec, DefaultingPyMlirContext context) {
2487             MlirType type =
2488                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2489             // TODO: Rework error reporting once diagnostic engine is exposed
2490             // in C API.
2491             if (mlirTypeIsNull(type)) {
2492               throw SetPyError(PyExc_ValueError,
2493                                Twine("Unable to parse type: '") + typeSpec +
2494                                    "'");
2495             }
2496             return PyType(context->getRef(), type);
2497           },
2498           py::arg("asm"), py::arg("context") = py::none(),
2499           kContextParseTypeDocstring)
2500       .def_property_readonly(
2501           "context", [](PyType &self) { return self.getContext().getObject(); },
2502           "Context that owns the Type")
2503       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2504       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2505       .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; })
2506       .def(
2507           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2508       .def(
2509           "__str__",
2510           [](PyType &self) {
2511             PyPrintAccumulator printAccum;
2512             mlirTypePrint(self, printAccum.getCallback(),
2513                           printAccum.getUserData());
2514             return printAccum.join();
2515           },
2516           "Returns the assembly form of the type.")
2517       .def("__repr__", [](PyType &self) {
2518         // Generally, assembly formats are not printed for __repr__ because
2519         // this can cause exceptionally long debug output and exceptions.
2520         // However, types are an exception as they typically have compact
2521         // assembly forms and printing them is useful.
2522         PyPrintAccumulator printAccum;
2523         printAccum.parts.append("Type(");
2524         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2525         printAccum.parts.append(")");
2526         return printAccum.join();
2527       });
2528 
2529   //----------------------------------------------------------------------------
  // Mapping of PyValue.
2531   //----------------------------------------------------------------------------
2532   py::class_<PyValue>(m, "Value", py::module_local())
2533       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2534       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2535       .def_property_readonly(
2536           "context",
2537           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2538           "Context in which the value lives.")
2539       .def(
2540           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2541           kDumpDocstring)
2542       .def_property_readonly(
2543           "owner",
2544           [](PyValue &self) {
2545             assert(mlirOperationEqual(self.getParentOperation()->get(),
2546                                       mlirOpResultGetOwner(self.get())) &&
2547                    "expected the owner of the value in Python to match that in "
2548                    "the IR");
2549             return self.getParentOperation().getObject();
2550           })
2551       .def("__eq__",
2552            [](PyValue &self, PyValue &other) {
2553              return self.get().ptr == other.get().ptr;
2554            })
2555       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2556       .def(
2557           "__str__",
2558           [](PyValue &self) {
2559             PyPrintAccumulator printAccum;
2560             printAccum.parts.append("Value(");
2561             mlirValuePrint(self.get(), printAccum.getCallback(),
2562                            printAccum.getUserData());
2563             printAccum.parts.append(")");
2564             return printAccum.join();
2565           },
2566           kValueDunderStrDocstring)
2567       .def_property_readonly("type", [](PyValue &self) {
2568         return PyType(self.getParentOperation()->getContext(),
2569                       mlirValueGetType(self.get()));
2570       });
  // Register the BlockArgument and OpResult bindings. They follow the Value
  // class binding above (presumably because they are exposed as Value
  // subclasses -- TODO confirm the ordering requirement).
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings: the list/iterator/map helper classes returned by the
  // properties and `__iter__` methods bound in this file.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  PyGlobalDebugFlag::bind(m);
2588 }
2589