1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include <pybind11/stl.h>
22 
23 namespace py = pybind11;
24 using namespace mlir;
25 using namespace mlir::python;
26 
27 using llvm::SmallVector;
28 using llvm::StringRef;
29 using llvm::Twine;
30 
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
34 
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

// The keyword names listed below are user-facing. They are expected to match
// the py::arg names used by the print()/get_asm() bindings (registered
// elsewhere; NOTE(review): keep both sides in sync when editing either).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
129 
130 //------------------------------------------------------------------------------
131 // Utilities.
132 //------------------------------------------------------------------------------
133 
134 /// Helper for creating an @classmethod.
135 template <class Func, typename... Args>
136 py::object classmethod(Func f, Args... args) {
137   py::object cf = py::cpp_function(f, args...);
138   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
139 }
140 
141 static py::object
142 createCustomDialectWrapper(const std::string &dialectNamespace,
143                            py::object dialectDescriptor) {
144   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
145   if (!dialectClass) {
146     // Use the base class.
147     return py::cast(PyDialect(std::move(dialectDescriptor)));
148   }
149 
150   // Create the custom implementation.
151   return (*dialectClass)(std::move(dialectDescriptor));
152 }
153 
154 static MlirStringRef toMlirStringRef(const std::string &s) {
155   return mlirStringRefCreate(s.data(), s.size());
156 }
157 
158 /// Wrapper for the global LLVM debugging flag.
159 struct PyGlobalDebugFlag {
160   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
161 
162   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
163 
164   static void bind(py::module &m) {
165     // Debug flags.
166     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
167         .def_property_static("flag", &PyGlobalDebugFlag::get,
168                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
169   }
170 };
171 
172 //------------------------------------------------------------------------------
173 // Collections.
174 //------------------------------------------------------------------------------
175 
176 namespace {
177 
178 class PyRegionIterator {
179 public:
180   PyRegionIterator(PyOperationRef operation)
181       : operation(std::move(operation)) {}
182 
183   PyRegionIterator &dunderIter() { return *this; }
184 
185   PyRegion dunderNext() {
186     operation->checkValid();
187     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
188       throw py::stop_iteration();
189     }
190     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
191     return PyRegion(operation, region);
192   }
193 
194   static void bind(py::module &m) {
195     py::class_<PyRegionIterator>(m, "RegionIterator")
196         .def("__iter__", &PyRegionIterator::dunderIter)
197         .def("__next__", &PyRegionIterator::dunderNext);
198   }
199 
200 private:
201   PyOperationRef operation;
202   int nextIndex = 0;
203 };
204 
205 /// Regions of an op are fixed length and indexed numerically so are represented
206 /// with a sequence-like container.
207 class PyRegionList {
208 public:
209   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
210 
211   intptr_t dunderLen() {
212     operation->checkValid();
213     return mlirOperationGetNumRegions(operation->get());
214   }
215 
216   PyRegion dunderGetItem(intptr_t index) {
217     // dunderLen checks validity.
218     if (index < 0 || index >= dunderLen()) {
219       throw SetPyError(PyExc_IndexError,
220                        "attempt to access out of bounds region");
221     }
222     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
223     return PyRegion(operation, region);
224   }
225 
226   static void bind(py::module &m) {
227     py::class_<PyRegionList>(m, "RegionSequence")
228         .def("__len__", &PyRegionList::dunderLen)
229         .def("__getitem__", &PyRegionList::dunderGetItem);
230   }
231 
232 private:
233   PyOperationRef operation;
234 };
235 
236 class PyBlockIterator {
237 public:
238   PyBlockIterator(PyOperationRef operation, MlirBlock next)
239       : operation(std::move(operation)), next(next) {}
240 
241   PyBlockIterator &dunderIter() { return *this; }
242 
243   PyBlock dunderNext() {
244     operation->checkValid();
245     if (mlirBlockIsNull(next)) {
246       throw py::stop_iteration();
247     }
248 
249     PyBlock returnBlock(operation, next);
250     next = mlirBlockGetNextInRegion(next);
251     return returnBlock;
252   }
253 
254   static void bind(py::module &m) {
255     py::class_<PyBlockIterator>(m, "BlockIterator")
256         .def("__iter__", &PyBlockIterator::dunderIter)
257         .def("__next__", &PyBlockIterator::dunderNext);
258   }
259 
260 private:
261   PyOperationRef operation;
262   MlirBlock next;
263 };
264 
265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
266 /// we present them as a more full-featured list-like container but optimize
267 /// it for forward iteration. Blocks are always owned by a region.
268 class PyBlockList {
269 public:
270   PyBlockList(PyOperationRef operation, MlirRegion region)
271       : operation(std::move(operation)), region(region) {}
272 
273   PyBlockIterator dunderIter() {
274     operation->checkValid();
275     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
276   }
277 
278   intptr_t dunderLen() {
279     operation->checkValid();
280     intptr_t count = 0;
281     MlirBlock block = mlirRegionGetFirstBlock(region);
282     while (!mlirBlockIsNull(block)) {
283       count += 1;
284       block = mlirBlockGetNextInRegion(block);
285     }
286     return count;
287   }
288 
289   PyBlock dunderGetItem(intptr_t index) {
290     operation->checkValid();
291     if (index < 0) {
292       throw SetPyError(PyExc_IndexError,
293                        "attempt to access out of bounds block");
294     }
295     MlirBlock block = mlirRegionGetFirstBlock(region);
296     while (!mlirBlockIsNull(block)) {
297       if (index == 0) {
298         return PyBlock(operation, block);
299       }
300       block = mlirBlockGetNextInRegion(block);
301       index -= 1;
302     }
303     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
304   }
305 
306   PyBlock appendBlock(py::args pyArgTypes) {
307     operation->checkValid();
308     llvm::SmallVector<MlirType, 4> argTypes;
309     argTypes.reserve(pyArgTypes.size());
310     for (auto &pyArg : pyArgTypes) {
311       argTypes.push_back(pyArg.cast<PyType &>());
312     }
313 
314     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
315     mlirRegionAppendOwnedBlock(region, block);
316     return PyBlock(operation, block);
317   }
318 
319   static void bind(py::module &m) {
320     py::class_<PyBlockList>(m, "BlockList")
321         .def("__getitem__", &PyBlockList::dunderGetItem)
322         .def("__iter__", &PyBlockList::dunderIter)
323         .def("__len__", &PyBlockList::dunderLen)
324         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
325   }
326 
327 private:
328   PyOperationRef operation;
329   MlirRegion region;
330 };
331 
332 class PyOperationIterator {
333 public:
334   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
335       : parentOperation(std::move(parentOperation)), next(next) {}
336 
337   PyOperationIterator &dunderIter() { return *this; }
338 
339   py::object dunderNext() {
340     parentOperation->checkValid();
341     if (mlirOperationIsNull(next)) {
342       throw py::stop_iteration();
343     }
344 
345     PyOperationRef returnOperation =
346         PyOperation::forOperation(parentOperation->getContext(), next);
347     next = mlirOperationGetNextInBlock(next);
348     return returnOperation->createOpView();
349   }
350 
351   static void bind(py::module &m) {
352     py::class_<PyOperationIterator>(m, "OperationIterator")
353         .def("__iter__", &PyOperationIterator::dunderIter)
354         .def("__next__", &PyOperationIterator::dunderNext);
355   }
356 
357 private:
358   PyOperationRef parentOperation;
359   MlirOperation next;
360 };
361 
362 /// Operations are exposed by the C-API as a forward-only linked list. In
363 /// Python, we present them as a more full-featured list-like container but
364 /// optimize it for forward iteration. Iterable operations are always owned
365 /// by a block.
366 class PyOperationList {
367 public:
368   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
369       : parentOperation(std::move(parentOperation)), block(block) {}
370 
371   PyOperationIterator dunderIter() {
372     parentOperation->checkValid();
373     return PyOperationIterator(parentOperation,
374                                mlirBlockGetFirstOperation(block));
375   }
376 
377   intptr_t dunderLen() {
378     parentOperation->checkValid();
379     intptr_t count = 0;
380     MlirOperation childOp = mlirBlockGetFirstOperation(block);
381     while (!mlirOperationIsNull(childOp)) {
382       count += 1;
383       childOp = mlirOperationGetNextInBlock(childOp);
384     }
385     return count;
386   }
387 
388   py::object dunderGetItem(intptr_t index) {
389     parentOperation->checkValid();
390     if (index < 0) {
391       throw SetPyError(PyExc_IndexError,
392                        "attempt to access out of bounds operation");
393     }
394     MlirOperation childOp = mlirBlockGetFirstOperation(block);
395     while (!mlirOperationIsNull(childOp)) {
396       if (index == 0) {
397         return PyOperation::forOperation(parentOperation->getContext(), childOp)
398             ->createOpView();
399       }
400       childOp = mlirOperationGetNextInBlock(childOp);
401       index -= 1;
402     }
403     throw SetPyError(PyExc_IndexError,
404                      "attempt to access out of bounds operation");
405   }
406 
407   static void bind(py::module &m) {
408     py::class_<PyOperationList>(m, "OperationList")
409         .def("__getitem__", &PyOperationList::dunderGetItem)
410         .def("__iter__", &PyOperationList::dunderIter)
411         .def("__len__", &PyOperationList::dunderLen);
412   }
413 
414 private:
415   PyOperationRef parentOperation;
416   MlirBlock block;
417 };
418 
419 } // namespace
420 
421 //------------------------------------------------------------------------------
422 // PyMlirContext
423 //------------------------------------------------------------------------------
424 
425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
426   py::gil_scoped_acquire acquire;
427   auto &liveContexts = getLiveContexts();
428   liveContexts[context.ptr] = this;
429 }
430 
431 PyMlirContext::~PyMlirContext() {
432   // Note that the only public way to construct an instance is via the
433   // forContext method, which always puts the associated handle into
434   // liveContexts.
435   py::gil_scoped_acquire acquire;
436   getLiveContexts().erase(context.ptr);
437   mlirContextDestroy(context);
438 }
439 
440 py::object PyMlirContext::getCapsule() {
441   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
442 }
443 
444 py::object PyMlirContext::createFromCapsule(py::object capsule) {
445   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
446   if (mlirContextIsNull(rawContext))
447     throw py::error_already_set();
448   return forContext(rawContext).releaseObject();
449 }
450 
451 PyMlirContext *PyMlirContext::createNewContextForInit() {
452   MlirContext context = mlirContextCreate();
453   mlirRegisterAllDialects(context);
454   return new PyMlirContext(context);
455 }
456 
457 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
458   py::gil_scoped_acquire acquire;
459   auto &liveContexts = getLiveContexts();
460   auto it = liveContexts.find(context.ptr);
461   if (it == liveContexts.end()) {
462     // Create.
463     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
464     py::object pyRef = py::cast(unownedContextWrapper);
465     assert(pyRef && "cast to py::object failed");
466     liveContexts[context.ptr] = unownedContextWrapper;
467     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
468   }
469   // Use existing.
470   py::object pyRef = py::cast(it->second);
471   return PyMlirContextRef(it->second, std::move(pyRef));
472 }
473 
474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
475   static LiveContextMap liveContexts;
476   return liveContexts;
477 }
478 
479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
480 
481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
482 
483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
484 
485 pybind11::object PyMlirContext::contextEnter() {
486   return PyThreadContextEntry::pushContext(*this);
487 }
488 
489 void PyMlirContext::contextExit(pybind11::object excType,
490                                 pybind11::object excVal,
491                                 pybind11::object excTb) {
492   PyThreadContextEntry::popContext(*this);
493 }
494 
495 PyMlirContext &DefaultingPyMlirContext::resolve() {
496   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
497   if (!context) {
498     throw SetPyError(
499         PyExc_RuntimeError,
500         "An MLIR function requires a Context but none was provided in the call "
501         "or from the surrounding environment. Either pass to the function with "
502         "a 'context=' argument or establish a default using 'with Context():'");
503   }
504   return *context;
505 }
506 
507 //------------------------------------------------------------------------------
508 // PyThreadContextEntry management
509 //------------------------------------------------------------------------------
510 
511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
512   static thread_local std::vector<PyThreadContextEntry> stack;
513   return stack;
514 }
515 
516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
517   auto &stack = getStack();
518   if (stack.empty())
519     return nullptr;
520   return &stack.back();
521 }
522 
523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
524                                 py::object insertionPoint,
525                                 py::object location) {
526   auto &stack = getStack();
527   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
528                      std::move(location));
529   // If the new stack has more than one entry and the context of the new top
530   // entry matches the previous, copy the insertionPoint and location from the
531   // previous entry if missing from the new top entry.
532   if (stack.size() > 1) {
533     auto &prev = *(stack.rbegin() + 1);
534     auto &current = stack.back();
535     if (current.context.is(prev.context)) {
536       // Default non-context objects from the previous entry.
537       if (!current.insertionPoint)
538         current.insertionPoint = prev.insertionPoint;
539       if (!current.location)
540         current.location = prev.location;
541     }
542   }
543 }
544 
545 PyMlirContext *PyThreadContextEntry::getContext() {
546   if (!context)
547     return nullptr;
548   return py::cast<PyMlirContext *>(context);
549 }
550 
551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
552   if (!insertionPoint)
553     return nullptr;
554   return py::cast<PyInsertionPoint *>(insertionPoint);
555 }
556 
557 PyLocation *PyThreadContextEntry::getLocation() {
558   if (!location)
559     return nullptr;
560   return py::cast<PyLocation *>(location);
561 }
562 
563 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
564   auto *tos = getTopOfStack();
565   return tos ? tos->getContext() : nullptr;
566 }
567 
568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
569   auto *tos = getTopOfStack();
570   return tos ? tos->getInsertionPoint() : nullptr;
571 }
572 
573 PyLocation *PyThreadContextEntry::getDefaultLocation() {
574   auto *tos = getTopOfStack();
575   return tos ? tos->getLocation() : nullptr;
576 }
577 
578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
579   py::object contextObj = py::cast(context);
580   push(FrameKind::Context, /*context=*/contextObj,
581        /*insertionPoint=*/py::object(),
582        /*location=*/py::object());
583   return contextObj;
584 }
585 
586 void PyThreadContextEntry::popContext(PyMlirContext &context) {
587   auto &stack = getStack();
588   if (stack.empty())
589     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
590   auto &tos = stack.back();
591   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
592     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
593   stack.pop_back();
594 }
595 
596 py::object
597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
598   py::object contextObj =
599       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
600   py::object insertionPointObj = py::cast(insertionPoint);
601   push(FrameKind::InsertionPoint,
602        /*context=*/contextObj,
603        /*insertionPoint=*/insertionPointObj,
604        /*location=*/py::object());
605   return insertionPointObj;
606 }
607 
608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
609   auto &stack = getStack();
610   if (stack.empty())
611     throw SetPyError(PyExc_RuntimeError,
612                      "Unbalanced InsertionPoint enter/exit");
613   auto &tos = stack.back();
614   if (tos.frameKind != FrameKind::InsertionPoint &&
615       tos.getInsertionPoint() != &insertionPoint)
616     throw SetPyError(PyExc_RuntimeError,
617                      "Unbalanced InsertionPoint enter/exit");
618   stack.pop_back();
619 }
620 
621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
622   py::object contextObj = location.getContext().getObject();
623   py::object locationObj = py::cast(location);
624   push(FrameKind::Location, /*context=*/contextObj,
625        /*insertionPoint=*/py::object(),
626        /*location=*/locationObj);
627   return locationObj;
628 }
629 
630 void PyThreadContextEntry::popLocation(PyLocation &location) {
631   auto &stack = getStack();
632   if (stack.empty())
633     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
634   auto &tos = stack.back();
635   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
636     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
637   stack.pop_back();
638 }
639 
640 //------------------------------------------------------------------------------
641 // PyDialect, PyDialectDescriptor, PyDialects
642 //------------------------------------------------------------------------------
643 
644 MlirDialect PyDialects::getDialectForKey(const std::string &key,
645                                          bool attrError) {
646   // If the "std" dialect was asked for, substitute the empty namespace :(
647   static const std::string emptyKey;
648   const std::string *canonKey = key == "std" ? &emptyKey : &key;
649   MlirDialect dialect = mlirContextGetOrLoadDialect(
650       getContext()->get(), {canonKey->data(), canonKey->size()});
651   if (mlirDialectIsNull(dialect)) {
652     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
653                      Twine("Dialect '") + key + "' not found");
654   }
655   return dialect;
656 }
657 
658 //------------------------------------------------------------------------------
659 // PyLocation
660 //------------------------------------------------------------------------------
661 
662 py::object PyLocation::getCapsule() {
663   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
664 }
665 
666 PyLocation PyLocation::createFromCapsule(py::object capsule) {
667   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
668   if (mlirLocationIsNull(rawLoc))
669     throw py::error_already_set();
670   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
671                     rawLoc);
672 }
673 
674 py::object PyLocation::contextEnter() {
675   return PyThreadContextEntry::pushLocation(*this);
676 }
677 
678 void PyLocation::contextExit(py::object excType, py::object excVal,
679                              py::object excTb) {
680   PyThreadContextEntry::popLocation(*this);
681 }
682 
683 PyLocation &DefaultingPyLocation::resolve() {
684   auto *location = PyThreadContextEntry::getDefaultLocation();
685   if (!location) {
686     throw SetPyError(
687         PyExc_RuntimeError,
688         "An MLIR function requires a Location but none was provided in the "
689         "call or from the surrounding environment. Either pass to the function "
690         "with a 'loc=' argument or establish a default using 'with loc:'");
691   }
692   return *location;
693 }
694 
695 //------------------------------------------------------------------------------
696 // PyModule
697 //------------------------------------------------------------------------------
698 
699 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
700     : BaseContextObject(std::move(contextRef)), module(module) {}
701 
702 PyModule::~PyModule() {
703   py::gil_scoped_acquire acquire;
704   auto &liveModules = getContext()->liveModules;
705   assert(liveModules.count(module.ptr) == 1 &&
706          "destroying module not in live map");
707   liveModules.erase(module.ptr);
708   mlirModuleDestroy(module);
709 }
710 
711 PyModuleRef PyModule::forModule(MlirModule module) {
712   MlirContext context = mlirModuleGetContext(module);
713   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
714 
715   py::gil_scoped_acquire acquire;
716   auto &liveModules = contextRef->liveModules;
717   auto it = liveModules.find(module.ptr);
718   if (it == liveModules.end()) {
719     // Create.
720     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
721     // Note that the default return value policy on cast is automatic_reference,
722     // which does not take ownership (delete will not be called).
723     // Just be explicit.
724     py::object pyRef =
725         py::cast(unownedModule, py::return_value_policy::take_ownership);
726     unownedModule->handle = pyRef;
727     liveModules[module.ptr] =
728         std::make_pair(unownedModule->handle, unownedModule);
729     return PyModuleRef(unownedModule, std::move(pyRef));
730   }
731   // Use existing.
732   PyModule *existing = it->second.second;
733   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
734   return PyModuleRef(existing, std::move(pyRef));
735 }
736 
737 py::object PyModule::createFromCapsule(py::object capsule) {
738   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
739   if (mlirModuleIsNull(rawModule))
740     throw py::error_already_set();
741   return forModule(rawModule).releaseObject();
742 }
743 
744 py::object PyModule::getCapsule() {
745   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
746 }
747 
748 //------------------------------------------------------------------------------
749 // PyOperation
750 //------------------------------------------------------------------------------
751 
752 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
753     : BaseContextObject(std::move(contextRef)), operation(operation) {}
754 
755 PyOperation::~PyOperation() {
756   // If the operation has already been invalidated there is nothing to do.
757   if (!valid)
758     return;
759   auto &liveOperations = getContext()->liveOperations;
760   assert(liveOperations.count(operation.ptr) == 1 &&
761          "destroying operation not in live map");
762   liveOperations.erase(operation.ptr);
763   if (!isAttached()) {
764     mlirOperationDestroy(operation);
765   }
766 }
767 
768 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
769                                            MlirOperation operation,
770                                            py::object parentKeepAlive) {
771   auto &liveOperations = contextRef->liveOperations;
772   // Create.
773   PyOperation *unownedOperation =
774       new PyOperation(std::move(contextRef), operation);
775   // Note that the default return value policy on cast is automatic_reference,
776   // which does not take ownership (delete will not be called).
777   // Just be explicit.
778   py::object pyRef =
779       py::cast(unownedOperation, py::return_value_policy::take_ownership);
780   unownedOperation->handle = pyRef;
781   if (parentKeepAlive) {
782     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
783   }
784   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
785   return PyOperationRef(unownedOperation, std::move(pyRef));
786 }
787 
788 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
789                                          MlirOperation operation,
790                                          py::object parentKeepAlive) {
791   auto &liveOperations = contextRef->liveOperations;
792   auto it = liveOperations.find(operation.ptr);
793   if (it == liveOperations.end()) {
794     // Create.
795     return createInstance(std::move(contextRef), operation,
796                           std::move(parentKeepAlive));
797   }
798   // Use existing.
799   PyOperation *existing = it->second.second;
800   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
801   return PyOperationRef(existing, std::move(pyRef));
802 }
803 
804 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
805                                            MlirOperation operation,
806                                            py::object parentKeepAlive) {
807   auto &liveOperations = contextRef->liveOperations;
808   assert(liveOperations.count(operation.ptr) == 0 &&
809          "cannot create detached operation that already exists");
810   (void)liveOperations;
811 
812   PyOperationRef created = createInstance(std::move(contextRef), operation,
813                                           std::move(parentKeepAlive));
814   created->attached = false;
815   return created;
816 }
817 
818 void PyOperation::checkValid() const {
819   if (!valid) {
820     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
821   }
822 }
823 
824 void PyOperationBase::print(py::object fileObject, bool binary,
825                             llvm::Optional<int64_t> largeElementsLimit,
826                             bool enableDebugInfo, bool prettyDebugInfo,
827                             bool printGenericOpForm, bool useLocalScope) {
828   PyOperation &operation = getOperation();
829   operation.checkValid();
830   if (fileObject.is_none())
831     fileObject = py::module::import("sys").attr("stdout");
832 
833   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
834     fileObject.attr("write")("// Verification failed, printing generic form\n");
835     printGenericOpForm = true;
836   }
837 
838   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
839   if (largeElementsLimit)
840     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
841   if (enableDebugInfo)
842     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
843   if (printGenericOpForm)
844     mlirOpPrintingFlagsPrintGenericOpForm(flags);
845 
846   PyFileAccumulator accum(fileObject, binary);
847   py::gil_scoped_release();
848   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
849                               accum.getUserData());
850   mlirOpPrintingFlagsDestroy(flags);
851 }
852 
853 py::object PyOperationBase::getAsm(bool binary,
854                                    llvm::Optional<int64_t> largeElementsLimit,
855                                    bool enableDebugInfo, bool prettyDebugInfo,
856                                    bool printGenericOpForm,
857                                    bool useLocalScope) {
858   py::object fileObject;
859   if (binary) {
860     fileObject = py::module::import("io").attr("BytesIO")();
861   } else {
862     fileObject = py::module::import("io").attr("StringIO")();
863   }
864   print(fileObject, /*binary=*/binary,
865         /*largeElementsLimit=*/largeElementsLimit,
866         /*enableDebugInfo=*/enableDebugInfo,
867         /*prettyDebugInfo=*/prettyDebugInfo,
868         /*printGenericOpForm=*/printGenericOpForm,
869         /*useLocalScope=*/useLocalScope);
870 
871   return fileObject.attr("getvalue")();
872 }
873 
874 PyOperationRef PyOperation::getParentOperation() {
875   checkValid();
876   if (!isAttached())
877     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
878   MlirOperation operation = mlirOperationGetParentOperation(get());
879   if (mlirOperationIsNull(operation))
880     throw SetPyError(PyExc_ValueError, "Operation has no parent.");
881   return PyOperation::forOperation(getContext(), operation);
882 }
883 
884 PyBlock PyOperation::getBlock() {
885   checkValid();
886   PyOperationRef parentOperation = getParentOperation();
887   MlirBlock block = mlirOperationGetBlock(get());
888   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
889   return PyBlock{std::move(parentOperation), block};
890 }
891 
892 py::object PyOperation::getCapsule() {
893   checkValid();
894   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
895 }
896 
897 py::object PyOperation::createFromCapsule(py::object capsule) {
898   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
899   if (mlirOperationIsNull(rawOperation))
900     throw py::error_already_set();
901   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
902   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
903       .releaseObject();
904 }
905 
906 py::object PyOperation::create(
907     std::string name, llvm::Optional<std::vector<PyType *>> results,
908     llvm::Optional<std::vector<PyValue *>> operands,
909     llvm::Optional<py::dict> attributes,
910     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
911     DefaultingPyLocation location, py::object maybeIp) {
912   llvm::SmallVector<MlirValue, 4> mlirOperands;
913   llvm::SmallVector<MlirType, 4> mlirResults;
914   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
915   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
916 
917   // General parameter validation.
918   if (regions < 0)
919     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
920 
921   // Unpack/validate operands.
922   if (operands) {
923     mlirOperands.reserve(operands->size());
924     for (PyValue *operand : *operands) {
925       if (!operand)
926         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
927       mlirOperands.push_back(operand->get());
928     }
929   }
930 
931   // Unpack/validate results.
932   if (results) {
933     mlirResults.reserve(results->size());
934     for (PyType *result : *results) {
935       // TODO: Verify result type originate from the same context.
936       if (!result)
937         throw SetPyError(PyExc_ValueError, "result type cannot be None");
938       mlirResults.push_back(*result);
939     }
940   }
941   // Unpack/validate attributes.
942   if (attributes) {
943     mlirAttributes.reserve(attributes->size());
944     for (auto &it : *attributes) {
945       std::string key;
946       try {
947         key = it.first.cast<std::string>();
948       } catch (py::cast_error &err) {
949         std::string msg = "Invalid attribute key (not a string) when "
950                           "attempting to create the operation \"" +
951                           name + "\" (" + err.what() + ")";
952         throw py::cast_error(msg);
953       }
954       try {
955         auto &attribute = it.second.cast<PyAttribute &>();
956         // TODO: Verify attribute originates from the same context.
957         mlirAttributes.emplace_back(std::move(key), attribute);
958       } catch (py::reference_cast_error &) {
959         // This exception seems thrown when the value is "None".
960         std::string msg =
961             "Found an invalid (`None`?) attribute value for the key \"" + key +
962             "\" when attempting to create the operation \"" + name + "\"";
963         throw py::cast_error(msg);
964       } catch (py::cast_error &err) {
965         std::string msg = "Invalid attribute value for the key \"" + key +
966                           "\" when attempting to create the operation \"" +
967                           name + "\" (" + err.what() + ")";
968         throw py::cast_error(msg);
969       }
970     }
971   }
972   // Unpack/validate successors.
973   if (successors) {
974     llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
975     mlirSuccessors.reserve(successors->size());
976     for (auto *successor : *successors) {
977       // TODO: Verify successor originate from the same context.
978       if (!successor)
979         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
980       mlirSuccessors.push_back(successor->get());
981     }
982   }
983 
984   // Apply unpacked/validated to the operation state. Beyond this
985   // point, exceptions cannot be thrown or else the state will leak.
986   MlirOperationState state =
987       mlirOperationStateGet(toMlirStringRef(name), location);
988   if (!mlirOperands.empty())
989     mlirOperationStateAddOperands(&state, mlirOperands.size(),
990                                   mlirOperands.data());
991   if (!mlirResults.empty())
992     mlirOperationStateAddResults(&state, mlirResults.size(),
993                                  mlirResults.data());
994   if (!mlirAttributes.empty()) {
995     // Note that the attribute names directly reference bytes in
996     // mlirAttributes, so that vector must not be changed from here
997     // on.
998     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
999     mlirNamedAttributes.reserve(mlirAttributes.size());
1000     for (auto &it : mlirAttributes)
1001       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1002           mlirIdentifierGet(mlirAttributeGetContext(it.second),
1003                             toMlirStringRef(it.first)),
1004           it.second));
1005     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1006                                     mlirNamedAttributes.data());
1007   }
1008   if (!mlirSuccessors.empty())
1009     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1010                                     mlirSuccessors.data());
1011   if (regions) {
1012     llvm::SmallVector<MlirRegion, 4> mlirRegions;
1013     mlirRegions.resize(regions);
1014     for (int i = 0; i < regions; ++i)
1015       mlirRegions[i] = mlirRegionCreate();
1016     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1017                                       mlirRegions.data());
1018   }
1019 
1020   // Construct the operation.
1021   MlirOperation operation = mlirOperationCreate(&state);
1022   PyOperationRef created =
1023       PyOperation::createDetached(location->getContext(), operation);
1024 
1025   // InsertPoint active?
1026   if (!maybeIp.is(py::cast(false))) {
1027     PyInsertionPoint *ip;
1028     if (maybeIp.is_none()) {
1029       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1030     } else {
1031       ip = py::cast<PyInsertionPoint *>(maybeIp);
1032     }
1033     if (ip)
1034       ip->insert(*created.get());
1035   }
1036 
1037   return created->createOpView();
1038 }
1039 
1040 py::object PyOperation::createOpView() {
1041   checkValid();
1042   MlirIdentifier ident = mlirOperationGetName(get());
1043   MlirStringRef identStr = mlirIdentifierStr(ident);
1044   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1045       StringRef(identStr.data, identStr.length));
1046   if (opViewClass)
1047     return (*opViewClass)(getRef().getObject());
1048   return py::cast(PyOpView(getRef().getObject()));
1049 }
1050 
1051 void PyOperation::erase() {
1052   checkValid();
1053   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1054   // Python reference to a child operation is live. All children should also
1055   // have their `valid` bit set to false.
1056   auto &liveOperations = getContext()->liveOperations;
1057   if (liveOperations.count(operation.ptr))
1058     liveOperations.erase(operation.ptr);
1059   mlirOperationDestroy(operation);
1060   valid = false;
1061 }
1062 
1063 //------------------------------------------------------------------------------
1064 // PyOpView
1065 //------------------------------------------------------------------------------
1066 
1067 py::object
1068 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1069                        py::list operandList,
1070                        llvm::Optional<py::dict> attributes,
1071                        llvm::Optional<std::vector<PyBlock *>> successors,
1072                        llvm::Optional<int> regions,
1073                        DefaultingPyLocation location, py::object maybeIp) {
1074   PyMlirContextRef context = location->getContext();
1075   // Class level operation construction metadata.
1076   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1077   // Operand and result segment specs are either none, which does no
1078   // variadic unpacking, or a list of ints with segment sizes, where each
1079   // element is either a positive number (typically 1 for a scalar) or -1 to
1080   // indicate that it is derived from the length of the same-indexed operand
1081   // or result (implying that it is a list at that position).
1082   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1083   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1084 
1085   std::vector<uint32_t> operandSegmentLengths;
1086   std::vector<uint32_t> resultSegmentLengths;
1087 
1088   // Validate/determine region count.
1089   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1090   int opMinRegionCount = std::get<0>(opRegionSpec);
1091   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1092   if (!regions) {
1093     regions = opMinRegionCount;
1094   }
1095   if (*regions < opMinRegionCount) {
1096     throw py::value_error(
1097         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1098          llvm::Twine(opMinRegionCount) +
1099          " regions but was built with regions=" + llvm::Twine(*regions))
1100             .str());
1101   }
1102   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1103     throw py::value_error(
1104         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1105          llvm::Twine(opMinRegionCount) +
1106          " regions but was built with regions=" + llvm::Twine(*regions))
1107             .str());
1108   }
1109 
1110   // Unpack results.
1111   std::vector<PyType *> resultTypes;
1112   resultTypes.reserve(resultTypeList.size());
1113   if (resultSegmentSpecObj.is_none()) {
1114     // Non-variadic result unpacking.
1115     for (auto it : llvm::enumerate(resultTypeList)) {
1116       try {
1117         resultTypes.push_back(py::cast<PyType *>(it.value()));
1118         if (!resultTypes.back())
1119           throw py::cast_error();
1120       } catch (py::cast_error &err) {
1121         throw py::value_error((llvm::Twine("Result ") +
1122                                llvm::Twine(it.index()) + " of operation \"" +
1123                                name + "\" must be a Type (" + err.what() + ")")
1124                                   .str());
1125       }
1126     }
1127   } else {
1128     // Sized result unpacking.
1129     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1130     if (resultSegmentSpec.size() != resultTypeList.size()) {
1131       throw py::value_error((llvm::Twine("Operation \"") + name +
1132                              "\" requires " +
1133                              llvm::Twine(resultSegmentSpec.size()) +
1134                              "result segments but was provided " +
1135                              llvm::Twine(resultTypeList.size()))
1136                                 .str());
1137     }
1138     resultSegmentLengths.reserve(resultTypeList.size());
1139     for (auto it :
1140          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1141       int segmentSpec = std::get<1>(it.value());
1142       if (segmentSpec == 1 || segmentSpec == 0) {
1143         // Unpack unary element.
1144         try {
1145           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1146           if (resultType) {
1147             resultTypes.push_back(resultType);
1148             resultSegmentLengths.push_back(1);
1149           } else if (segmentSpec == 0) {
1150             // Allowed to be optional.
1151             resultSegmentLengths.push_back(0);
1152           } else {
1153             throw py::cast_error("was None and result is not optional");
1154           }
1155         } catch (py::cast_error &err) {
1156           throw py::value_error((llvm::Twine("Result ") +
1157                                  llvm::Twine(it.index()) + " of operation \"" +
1158                                  name + "\" must be a Type (" + err.what() +
1159                                  ")")
1160                                     .str());
1161         }
1162       } else if (segmentSpec == -1) {
1163         // Unpack sequence by appending.
1164         try {
1165           if (std::get<0>(it.value()).is_none()) {
1166             // Treat it as an empty list.
1167             resultSegmentLengths.push_back(0);
1168           } else {
1169             // Unpack the list.
1170             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1171             for (py::object segmentItem : segment) {
1172               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1173               if (!resultTypes.back()) {
1174                 throw py::cast_error("contained a None item");
1175               }
1176             }
1177             resultSegmentLengths.push_back(segment.size());
1178           }
1179         } catch (std::exception &err) {
1180           // NOTE: Sloppy to be using a catch-all here, but there are at least
1181           // three different unrelated exceptions that can be thrown in the
1182           // above "casts". Just keep the scope above small and catch them all.
1183           throw py::value_error((llvm::Twine("Result ") +
1184                                  llvm::Twine(it.index()) + " of operation \"" +
1185                                  name + "\" must be a Sequence of Types (" +
1186                                  err.what() + ")")
1187                                     .str());
1188         }
1189       } else {
1190         throw py::value_error("Unexpected segment spec");
1191       }
1192     }
1193   }
1194 
1195   // Unpack operands.
1196   std::vector<PyValue *> operands;
1197   operands.reserve(operands.size());
1198   if (operandSegmentSpecObj.is_none()) {
1199     // Non-sized operand unpacking.
1200     for (auto it : llvm::enumerate(operandList)) {
1201       try {
1202         operands.push_back(py::cast<PyValue *>(it.value()));
1203         if (!operands.back())
1204           throw py::cast_error();
1205       } catch (py::cast_error &err) {
1206         throw py::value_error((llvm::Twine("Operand ") +
1207                                llvm::Twine(it.index()) + " of operation \"" +
1208                                name + "\" must be a Value (" + err.what() + ")")
1209                                   .str());
1210       }
1211     }
1212   } else {
1213     // Sized operand unpacking.
1214     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1215     if (operandSegmentSpec.size() != operandList.size()) {
1216       throw py::value_error((llvm::Twine("Operation \"") + name +
1217                              "\" requires " +
1218                              llvm::Twine(operandSegmentSpec.size()) +
1219                              "operand segments but was provided " +
1220                              llvm::Twine(operandList.size()))
1221                                 .str());
1222     }
1223     operandSegmentLengths.reserve(operandList.size());
1224     for (auto it :
1225          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1226       int segmentSpec = std::get<1>(it.value());
1227       if (segmentSpec == 1 || segmentSpec == 0) {
1228         // Unpack unary element.
1229         try {
1230           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1231           if (operandValue) {
1232             operands.push_back(operandValue);
1233             operandSegmentLengths.push_back(1);
1234           } else if (segmentSpec == 0) {
1235             // Allowed to be optional.
1236             operandSegmentLengths.push_back(0);
1237           } else {
1238             throw py::cast_error("was None and operand is not optional");
1239           }
1240         } catch (py::cast_error &err) {
1241           throw py::value_error((llvm::Twine("Operand ") +
1242                                  llvm::Twine(it.index()) + " of operation \"" +
1243                                  name + "\" must be a Value (" + err.what() +
1244                                  ")")
1245                                     .str());
1246         }
1247       } else if (segmentSpec == -1) {
1248         // Unpack sequence by appending.
1249         try {
1250           if (std::get<0>(it.value()).is_none()) {
1251             // Treat it as an empty list.
1252             operandSegmentLengths.push_back(0);
1253           } else {
1254             // Unpack the list.
1255             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1256             for (py::object segmentItem : segment) {
1257               operands.push_back(py::cast<PyValue *>(segmentItem));
1258               if (!operands.back()) {
1259                 throw py::cast_error("contained a None item");
1260               }
1261             }
1262             operandSegmentLengths.push_back(segment.size());
1263           }
1264         } catch (std::exception &err) {
1265           // NOTE: Sloppy to be using a catch-all here, but there are at least
1266           // three different unrelated exceptions that can be thrown in the
1267           // above "casts". Just keep the scope above small and catch them all.
1268           throw py::value_error((llvm::Twine("Operand ") +
1269                                  llvm::Twine(it.index()) + " of operation \"" +
1270                                  name + "\" must be a Sequence of Values (" +
1271                                  err.what() + ")")
1272                                     .str());
1273         }
1274       } else {
1275         throw py::value_error("Unexpected segment spec");
1276       }
1277     }
1278   }
1279 
1280   // Merge operand/result segment lengths into attributes if needed.
1281   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1282     // Dup.
1283     if (attributes) {
1284       attributes = py::dict(*attributes);
1285     } else {
1286       attributes = py::dict();
1287     }
1288     if (attributes->contains("result_segment_sizes") ||
1289         attributes->contains("operand_segment_sizes")) {
1290       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1291                             "'operand_segment_sizes' attribute is unsupported. "
1292                             "Use Operation.create for such low-level access.");
1293     }
1294 
1295     // Add result_segment_sizes attribute.
1296     if (!resultSegmentLengths.empty()) {
1297       int64_t size = resultSegmentLengths.size();
1298       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1299           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1300           resultSegmentLengths.size(), resultSegmentLengths.data());
1301       (*attributes)["result_segment_sizes"] =
1302           PyAttribute(context, segmentLengthAttr);
1303     }
1304 
1305     // Add operand_segment_sizes attribute.
1306     if (!operandSegmentLengths.empty()) {
1307       int64_t size = operandSegmentLengths.size();
1308       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1309           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1310           operandSegmentLengths.size(), operandSegmentLengths.data());
1311       (*attributes)["operand_segment_sizes"] =
1312           PyAttribute(context, segmentLengthAttr);
1313     }
1314   }
1315 
1316   // Delegate to create.
1317   return PyOperation::create(std::move(name),
1318                              /*results=*/std::move(resultTypes),
1319                              /*operands=*/std::move(operands),
1320                              /*attributes=*/std::move(attributes),
1321                              /*successors=*/std::move(successors),
1322                              /*regions=*/*regions, location, maybeIp);
1323 }
1324 
/// Wraps an existing operation as an OpView. `operationObject` may be any
/// Python object castable to PyOperationBase (e.g. an Operation or another
/// OpView). The stored object is re-derived from the underlying PyOperation.
/// NOTE(review): in the first initializer, `operationObject` names the
/// constructor parameter. The second initializer reads the `operation` member,
/// so this relies on `operation` being declared before `operationObject` in
/// the class — confirm in IRModule.h.
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1330 
1331 py::object PyOpView::createRawSubclass(py::object userClass) {
1332   // This is... a little gross. The typical pattern is to have a pure python
1333   // class that extends OpView like:
1334   //   class AddFOp(_cext.ir.OpView):
1335   //     def __init__(self, loc, lhs, rhs):
1336   //       operation = loc.context.create_operation(
1337   //           "addf", lhs, rhs, results=[lhs.type])
1338   //       super().__init__(operation)
1339   //
1340   // I.e. The goal of the user facing type is to provide a nice constructor
1341   // that has complete freedom for the op under construction. This is at odds
1342   // with our other desire to sometimes create this object by just passing an
1343   // operation (to initialize the base class). We could do *arg and **kwargs
1344   // munging to try to make it work, but instead, we synthesize a new class
1345   // on the fly which extends this user class (AddFOp in this example) and
1346   // *give it* the base class's __init__ method, thus bypassing the
1347   // intermediate subclass's __init__ method entirely. While slightly,
1348   // underhanded, this is safe/legal because the type hierarchy has not changed
1349   // (we just added a new leaf) and we aren't mucking around with __new__.
1350   // Typically, this new class will be stored on the original as "_Raw" and will
1351   // be used for casts and other things that need a variant of the class that
1352   // is initialized purely from an operation.
1353   py::object parentMetaclass =
1354       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1355   py::dict attributes;
1356   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1357   // now.
1358   //   auto opViewType = py::type::of<PyOpView>();
1359   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1360   attributes["__init__"] = opViewType.attr("__init__");
1361   py::str origName = userClass.attr("__name__");
1362   py::str newName = py::str("_") + origName;
1363   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1364 }
1365 
1366 //------------------------------------------------------------------------------
1367 // PyInsertionPoint.
1368 //------------------------------------------------------------------------------
1369 
/// Creates an insertion point with no reference operation; insert() then
/// appends operations at the end of `block`.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1371 
/// Creates an insertion point immediately before `beforeOperationBase`, in the
/// block that contains it. getBlock() raises ValueError if that operation is
/// detached or has no parent.
/// NOTE(review): `block` is initialized from the `refOperation` member, so
/// this relies on `refOperation` being declared before `block` in the class —
/// confirm in IRModule.h.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1375 
1376 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1377   PyOperation &operation = operationBase.getOperation();
1378   if (operation.isAttached())
1379     throw SetPyError(PyExc_ValueError,
1380                      "Attempt to insert operation that is already attached");
1381   block.getParentOperation()->checkValid();
1382   MlirOperation beforeOp = {nullptr};
1383   if (refOperation) {
1384     // Insert before operation.
1385     (*refOperation)->checkValid();
1386     beforeOp = (*refOperation)->get();
1387   } else {
1388     // Insert at end (before null) is only valid if the block does not
1389     // already end in a known terminator (violating this will cause assertion
1390     // failures later).
1391     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1392       throw py::index_error("Cannot insert operation at the end of a block "
1393                             "that already has a terminator. Did you mean to "
1394                             "use 'InsertionPoint.at_block_terminator(block)' "
1395                             "versus 'InsertionPoint(block)'?");
1396     }
1397   }
1398   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1399   operation.setAttached();
1400 }
1401 
1402 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1403   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1404   if (mlirOperationIsNull(firstOp)) {
1405     // Just insert at end.
1406     return PyInsertionPoint(block);
1407   }
1408 
1409   // Insert before first op.
1410   PyOperationRef firstOpRef = PyOperation::forOperation(
1411       block.getParentOperation()->getContext(), firstOp);
1412   return PyInsertionPoint{block, std::move(firstOpRef)};
1413 }
1414 
1415 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1416   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1417   if (mlirOperationIsNull(terminator))
1418     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1419   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1420       block.getParentOperation()->getContext(), terminator);
1421   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1422 }
1423 
1424 py::object PyInsertionPoint::contextEnter() {
1425   return PyThreadContextEntry::pushInsertionPoint(*this);
1426 }
1427 
1428 void PyInsertionPoint::contextExit(pybind11::object excType,
1429                                    pybind11::object excVal,
1430                                    pybind11::object excTb) {
1431   PyThreadContextEntry::popInsertionPoint(*this);
1432 }
1433 
1434 //------------------------------------------------------------------------------
1435 // PyAttribute.
1436 //------------------------------------------------------------------------------
1437 
1438 bool PyAttribute::operator==(const PyAttribute &other) {
1439   return mlirAttributeEqual(attr, other.attr);
1440 }
1441 
1442 py::object PyAttribute::getCapsule() {
1443   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1444 }
1445 
1446 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1447   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1448   if (mlirAttributeIsNull(rawAttr))
1449     throw py::error_already_set();
1450   return PyAttribute(
1451       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1452 }
1453 
1454 //------------------------------------------------------------------------------
1455 // PyNamedAttribute.
1456 //------------------------------------------------------------------------------
1457 
/// Builds a named attribute pairing `attr` with the identifier `ownedName`,
/// interned in `attr`'s context.
/// The name is moved into a heap-allocated std::string kept by this object,
/// and the identifier is created from that copy via toMlirStringRef.
/// NOTE(review): presumably heap-allocated so that the string's address stays
/// stable if this object is moved; the member type (likely a smart pointer)
/// is declared in IRModule.h — confirm.
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
    : ownedName(new std::string(std::move(ownedName))) {
  namedAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(mlirAttributeGetContext(attr),
                        toMlirStringRef(*this->ownedName)),
      attr);
}
1465 
1466 //------------------------------------------------------------------------------
1467 // PyType.
1468 //------------------------------------------------------------------------------
1469 
1470 bool PyType::operator==(const PyType &other) {
1471   return mlirTypeEqual(type, other.type);
1472 }
1473 
1474 py::object PyType::getCapsule() {
1475   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1476 }
1477 
1478 PyType PyType::createFromCapsule(py::object capsule) {
1479   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1480   if (mlirTypeIsNull(rawType))
1481     throw py::error_already_set();
1482   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1483                 rawType);
1484 }
1485 
1486 //------------------------------------------------------------------------------
1487 // PyValue and subclases.
1488 //------------------------------------------------------------------------------
1489 
1490 pybind11::object PyValue::getCapsule() {
1491   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1492 }
1493 
1494 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1495   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1496   if (mlirValueIsNull(value))
1497     throw py::error_already_set();
1498   MlirOperation owner;
1499   if (mlirValueIsAOpResult(value))
1500     owner = mlirOpResultGetOwner(value);
1501   if (mlirValueIsABlockArgument(value))
1502     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1503   if (mlirOperationIsNull(owner))
1504     throw py::error_already_set();
1505   MlirContext ctx = mlirOperationGetContext(owner);
1506   PyOperationRef ownerRef =
1507       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1508   return PyValue(ownerRef, value);
1509 }
1510 
1511 namespace {
1512 /// CRTP base class for Python MLIR values that subclass Value and should be
1513 /// castable from it. The value hierarchy is one level deep and is not supposed
1514 /// to accommodate other levels unless core MLIR changes.
1515 template <typename DerivedTy>
1516 class PyConcreteValue : public PyValue {
1517 public:
1518   // Derived classes must define statics for:
1519   //   IsAFunctionTy isaFunction
1520   //   const char *pyClassName
1521   // and redefine bindDerived.
1522   using ClassTy = py::class_<DerivedTy, PyValue>;
1523   using IsAFunctionTy = bool (*)(MlirValue);
1524 
1525   PyConcreteValue() = default;
1526   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1527       : PyValue(operationRef, value) {}
1528   PyConcreteValue(PyValue &orig)
1529       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1530 
1531   /// Attempts to cast the original value to the derived type and throws on
1532   /// type mismatches.
1533   static MlirValue castFrom(PyValue &orig) {
1534     if (!DerivedTy::isaFunction(orig.get())) {
1535       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1536       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1537                                              DerivedTy::pyClassName +
1538                                              " (from " + origRepr + ")");
1539     }
1540     return orig.get();
1541   }
1542 
1543   /// Binds the Python module objects to functions of this class.
1544   static void bind(py::module &m) {
1545     auto cls = ClassTy(m, DerivedTy::pyClassName);
1546     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1547     DerivedTy::bindDerived(cls);
1548   }
1549 
1550   /// Implemented by derived classes to add methods to the Python subclass.
1551   static void bindDerived(ClassTy &m) {}
1552 };
1553 
1554 /// Python wrapper for MlirBlockArgument.
1555 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1556 public:
1557   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1558   static constexpr const char *pyClassName = "BlockArgument";
1559   using PyConcreteValue::PyConcreteValue;
1560 
1561   static void bindDerived(ClassTy &c) {
1562     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1563       return PyBlock(self.getParentOperation(),
1564                      mlirBlockArgumentGetOwner(self.get()));
1565     });
1566     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1567       return mlirBlockArgumentGetArgNumber(self.get());
1568     });
1569     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1570       return mlirBlockArgumentSetType(self.get(), type);
1571     });
1572   }
1573 };
1574 
1575 /// Python wrapper for MlirOpResult.
1576 class PyOpResult : public PyConcreteValue<PyOpResult> {
1577 public:
1578   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1579   static constexpr const char *pyClassName = "OpResult";
1580   using PyConcreteValue::PyConcreteValue;
1581 
1582   static void bindDerived(ClassTy &c) {
1583     c.def_property_readonly("owner", [](PyOpResult &self) {
1584       assert(
1585           mlirOperationEqual(self.getParentOperation()->get(),
1586                              mlirOpResultGetOwner(self.get())) &&
1587           "expected the owner of the value in Python to match that in the IR");
1588       return self.getParentOperation().getObject();
1589     });
1590     c.def_property_readonly("result_number", [](PyOpResult &self) {
1591       return mlirOpResultGetResultNumber(self.get());
1592     });
1593   }
1594 };
1595 
1596 /// A list of block arguments. Internally, these are stored as consecutive
1597 /// elements, random access is cheap. The argument list is associated with the
1598 /// operation that contains the block (detached blocks are not allowed in
1599 /// Python bindings) and extends its lifetime.
1600 class PyBlockArgumentList {
1601 public:
1602   PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
1603       : operation(std::move(operation)), block(block) {}
1604 
1605   /// Returns the length of the block argument list.
1606   intptr_t dunderLen() {
1607     operation->checkValid();
1608     return mlirBlockGetNumArguments(block);
1609   }
1610 
1611   /// Returns `index`-th element of the block argument list.
1612   PyBlockArgument dunderGetItem(intptr_t index) {
1613     if (index < 0 || index >= dunderLen()) {
1614       throw SetPyError(PyExc_IndexError,
1615                        "attempt to access out of bounds region");
1616     }
1617     PyValue value(operation, mlirBlockGetArgument(block, index));
1618     return PyBlockArgument(value);
1619   }
1620 
1621   /// Defines a Python class in the bindings.
1622   static void bind(py::module &m) {
1623     py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
1624         .def("__len__", &PyBlockArgumentList::dunderLen)
1625         .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
1626   }
1627 
1628 private:
1629   PyOperationRef operation;
1630   MlirBlock block;
1631 };
1632 
1633 /// A list of operation operands. Internally, these are stored as consecutive
1634 /// elements, random access is cheap. The result list is associated with the
1635 /// operation whose results these are, and extends the lifetime of this
1636 /// operation.
1637 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1638 public:
1639   static constexpr const char *pyClassName = "OpOperandList";
1640 
1641   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1642                   intptr_t length = -1, intptr_t step = 1)
1643       : Sliceable(startIndex,
1644                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1645                                : length,
1646                   step),
1647         operation(operation) {}
1648 
1649   intptr_t getNumElements() {
1650     operation->checkValid();
1651     return mlirOperationGetNumOperands(operation->get());
1652   }
1653 
1654   PyValue getElement(intptr_t pos) {
1655     return PyValue(operation, mlirOperationGetOperand(operation->get(), pos));
1656   }
1657 
1658   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1659     return PyOpOperandList(operation, startIndex, length, step);
1660   }
1661 
1662   void dunderSetItem(intptr_t index, PyValue value) {
1663     index = wrapIndex(index);
1664     mlirOperationSetOperand(operation->get(), index, value.get());
1665   }
1666 
1667   static void bindDerived(ClassTy &c) {
1668     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1669   }
1670 
1671 private:
1672   PyOperationRef operation;
1673 };
1674 
1675 /// A list of operation results. Internally, these are stored as consecutive
1676 /// elements, random access is cheap. The result list is associated with the
1677 /// operation whose results these are, and extends the lifetime of this
1678 /// operation.
1679 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1680 public:
1681   static constexpr const char *pyClassName = "OpResultList";
1682 
1683   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1684                  intptr_t length = -1, intptr_t step = 1)
1685       : Sliceable(startIndex,
1686                   length == -1 ? mlirOperationGetNumResults(operation->get())
1687                                : length,
1688                   step),
1689         operation(operation) {}
1690 
1691   intptr_t getNumElements() {
1692     operation->checkValid();
1693     return mlirOperationGetNumResults(operation->get());
1694   }
1695 
1696   PyOpResult getElement(intptr_t index) {
1697     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1698     return PyOpResult(value);
1699   }
1700 
1701   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1702     return PyOpResultList(operation, startIndex, length, step);
1703   }
1704 
1705 private:
1706   PyOperationRef operation;
1707 };
1708 
1709 /// A list of operation attributes. Can be indexed by name, producing
1710 /// attributes, or by index, producing named attributes.
1711 class PyOpAttributeMap {
1712 public:
1713   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1714 
1715   PyAttribute dunderGetItemNamed(const std::string &name) {
1716     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1717                                                          toMlirStringRef(name));
1718     if (mlirAttributeIsNull(attr)) {
1719       throw SetPyError(PyExc_KeyError,
1720                        "attempt to access a non-existent attribute");
1721     }
1722     return PyAttribute(operation->getContext(), attr);
1723   }
1724 
1725   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1726     if (index < 0 || index >= dunderLen()) {
1727       throw SetPyError(PyExc_IndexError,
1728                        "attempt to access out of bounds attribute");
1729     }
1730     MlirNamedAttribute namedAttr =
1731         mlirOperationGetAttribute(operation->get(), index);
1732     return PyNamedAttribute(
1733         namedAttr.attribute,
1734         std::string(mlirIdentifierStr(namedAttr.name).data));
1735   }
1736 
1737   void dunderSetItem(const std::string &name, PyAttribute attr) {
1738     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1739                                     attr);
1740   }
1741 
1742   void dunderDelItem(const std::string &name) {
1743     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1744                                                      toMlirStringRef(name));
1745     if (!removed)
1746       throw SetPyError(PyExc_KeyError,
1747                        "attempt to delete a non-existent attribute");
1748   }
1749 
1750   intptr_t dunderLen() {
1751     return mlirOperationGetNumAttributes(operation->get());
1752   }
1753 
1754   bool dunderContains(const std::string &name) {
1755     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1756         operation->get(), toMlirStringRef(name)));
1757   }
1758 
1759   static void bind(py::module &m) {
1760     py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
1761         .def("__contains__", &PyOpAttributeMap::dunderContains)
1762         .def("__len__", &PyOpAttributeMap::dunderLen)
1763         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1764         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1765         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1766         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1767   }
1768 
1769 private:
1770   PyOperationRef operation;
1771 };
1772 
1773 } // end namespace
1774 
1775 //------------------------------------------------------------------------------
1776 // Populates the core exports of the 'ir' submodule.
1777 //------------------------------------------------------------------------------
1778 
1779 void mlir::python::populateIRCore(py::module &m) {
1780   //----------------------------------------------------------------------------
1781   // Mapping of MlirContext.
1782   //----------------------------------------------------------------------------
1783   py::class_<PyMlirContext>(m, "Context")
1784       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1785       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1786       .def("_get_context_again",
1787            [](PyMlirContext &self) {
1788              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1789              return ref.releaseObject();
1790            })
1791       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1792       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1793       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1794                              &PyMlirContext::getCapsule)
1795       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1796       .def("__enter__", &PyMlirContext::contextEnter)
1797       .def("__exit__", &PyMlirContext::contextExit)
1798       .def_property_readonly_static(
1799           "current",
1800           [](py::object & /*class*/) {
1801             auto *context = PyThreadContextEntry::getDefaultContext();
1802             if (!context)
1803               throw SetPyError(PyExc_ValueError, "No current Context");
1804             return context;
1805           },
1806           "Gets the Context bound to the current thread or raises ValueError")
1807       .def_property_readonly(
1808           "dialects",
1809           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1810           "Gets a container for accessing dialects by name")
1811       .def_property_readonly(
1812           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1813           "Alias for 'dialect'")
1814       .def(
1815           "get_dialect_descriptor",
1816           [=](PyMlirContext &self, std::string &name) {
1817             MlirDialect dialect = mlirContextGetOrLoadDialect(
1818                 self.get(), {name.data(), name.size()});
1819             if (mlirDialectIsNull(dialect)) {
1820               throw SetPyError(PyExc_ValueError,
1821                                Twine("Dialect '") + name + "' not found");
1822             }
1823             return PyDialectDescriptor(self.getRef(), dialect);
1824           },
1825           "Gets or loads a dialect by name, returning its descriptor object")
1826       .def_property(
1827           "allow_unregistered_dialects",
1828           [](PyMlirContext &self) -> bool {
1829             return mlirContextGetAllowUnregisteredDialects(self.get());
1830           },
1831           [](PyMlirContext &self, bool value) {
1832             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1833           })
1834       .def("enable_multithreading",
1835            [](PyMlirContext &self, bool enable) {
1836              mlirContextEnableMultithreading(self.get(), enable);
1837            })
1838       .def("is_registered_operation",
1839            [](PyMlirContext &self, std::string &name) {
1840              return mlirContextIsRegisteredOperation(
1841                  self.get(), MlirStringRef{name.data(), name.size()});
1842            });
1843 
1844   //----------------------------------------------------------------------------
1845   // Mapping of PyDialectDescriptor
1846   //----------------------------------------------------------------------------
1847   py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
1848       .def_property_readonly("namespace",
1849                              [](PyDialectDescriptor &self) {
1850                                MlirStringRef ns =
1851                                    mlirDialectGetNamespace(self.get());
1852                                return py::str(ns.data, ns.length);
1853                              })
1854       .def("__repr__", [](PyDialectDescriptor &self) {
1855         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1856         std::string repr("<DialectDescriptor ");
1857         repr.append(ns.data, ns.length);
1858         repr.append(">");
1859         return repr;
1860       });
1861 
1862   //----------------------------------------------------------------------------
1863   // Mapping of PyDialects
1864   //----------------------------------------------------------------------------
1865   py::class_<PyDialects>(m, "Dialects")
1866       .def("__getitem__",
1867            [=](PyDialects &self, std::string keyName) {
1868              MlirDialect dialect =
1869                  self.getDialectForKey(keyName, /*attrError=*/false);
1870              py::object descriptor =
1871                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1872              return createCustomDialectWrapper(keyName, std::move(descriptor));
1873            })
1874       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1875         MlirDialect dialect =
1876             self.getDialectForKey(attrName, /*attrError=*/true);
1877         py::object descriptor =
1878             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1879         return createCustomDialectWrapper(attrName, std::move(descriptor));
1880       });
1881 
1882   //----------------------------------------------------------------------------
1883   // Mapping of PyDialect
1884   //----------------------------------------------------------------------------
1885   py::class_<PyDialect>(m, "Dialect")
1886       .def(py::init<py::object>(), "descriptor")
1887       .def_property_readonly(
1888           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1889       .def("__repr__", [](py::object self) {
1890         auto clazz = self.attr("__class__");
1891         return py::str("<Dialect ") +
1892                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1893                clazz.attr("__module__") + py::str(".") +
1894                clazz.attr("__name__") + py::str(")>");
1895       });
1896 
1897   //----------------------------------------------------------------------------
1898   // Mapping of Location
1899   //----------------------------------------------------------------------------
1900   py::class_<PyLocation>(m, "Location")
1901       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
1902       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
1903       .def("__enter__", &PyLocation::contextEnter)
1904       .def("__exit__", &PyLocation::contextExit)
1905       .def("__eq__",
1906            [](PyLocation &self, PyLocation &other) -> bool {
1907              return mlirLocationEqual(self, other);
1908            })
1909       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
1910       .def_property_readonly_static(
1911           "current",
1912           [](py::object & /*class*/) {
1913             auto *loc = PyThreadContextEntry::getDefaultLocation();
1914             if (!loc)
1915               throw SetPyError(PyExc_ValueError, "No current Location");
1916             return loc;
1917           },
1918           "Gets the Location bound to the current thread or raises ValueError")
1919       .def_static(
1920           "unknown",
1921           [](DefaultingPyMlirContext context) {
1922             return PyLocation(context->getRef(),
1923                               mlirLocationUnknownGet(context->get()));
1924           },
1925           py::arg("context") = py::none(),
1926           "Gets a Location representing an unknown location")
1927       .def_static(
1928           "file",
1929           [](std::string filename, int line, int col,
1930              DefaultingPyMlirContext context) {
1931             return PyLocation(
1932                 context->getRef(),
1933                 mlirLocationFileLineColGet(
1934                     context->get(), toMlirStringRef(filename), line, col));
1935           },
1936           py::arg("filename"), py::arg("line"), py::arg("col"),
1937           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
1938       .def_property_readonly(
1939           "context",
1940           [](PyLocation &self) { return self.getContext().getObject(); },
1941           "Context that owns the Location")
1942       .def("__repr__", [](PyLocation &self) {
1943         PyPrintAccumulator printAccum;
1944         mlirLocationPrint(self, printAccum.getCallback(),
1945                           printAccum.getUserData());
1946         return printAccum.join();
1947       });
1948 
1949   //----------------------------------------------------------------------------
1950   // Mapping of Module
1951   //----------------------------------------------------------------------------
1952   py::class_<PyModule>(m, "Module")
1953       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
1954       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
1955       .def_static(
1956           "parse",
1957           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
1958             MlirModule module = mlirModuleCreateParse(
1959                 context->get(), toMlirStringRef(moduleAsm));
1960             // TODO: Rework error reporting once diagnostic engine is exposed
1961             // in C API.
1962             if (mlirModuleIsNull(module)) {
1963               throw SetPyError(
1964                   PyExc_ValueError,
1965                   "Unable to parse module assembly (see diagnostics)");
1966             }
1967             return PyModule::forModule(module).releaseObject();
1968           },
1969           py::arg("asm"), py::arg("context") = py::none(),
1970           kModuleParseDocstring)
1971       .def_static(
1972           "create",
1973           [](DefaultingPyLocation loc) {
1974             MlirModule module = mlirModuleCreateEmpty(loc);
1975             return PyModule::forModule(module).releaseObject();
1976           },
1977           py::arg("loc") = py::none(), "Creates an empty module")
1978       .def_property_readonly(
1979           "context",
1980           [](PyModule &self) { return self.getContext().getObject(); },
1981           "Context that created the Module")
1982       .def_property_readonly(
1983           "operation",
1984           [](PyModule &self) {
1985             return PyOperation::forOperation(self.getContext(),
1986                                              mlirModuleGetOperation(self.get()),
1987                                              self.getRef().releaseObject())
1988                 .releaseObject();
1989           },
1990           "Accesses the module as an operation")
1991       .def_property_readonly(
1992           "body",
1993           [](PyModule &self) {
1994             PyOperationRef module_op = PyOperation::forOperation(
1995                 self.getContext(), mlirModuleGetOperation(self.get()),
1996                 self.getRef().releaseObject());
1997             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
1998             return returnBlock;
1999           },
2000           "Return the block for this module")
2001       .def(
2002           "dump",
2003           [](PyModule &self) {
2004             mlirOperationDump(mlirModuleGetOperation(self.get()));
2005           },
2006           kDumpDocstring)
2007       .def(
2008           "__str__",
2009           [](PyModule &self) {
2010             MlirOperation operation = mlirModuleGetOperation(self.get());
2011             PyPrintAccumulator printAccum;
2012             mlirOperationPrint(operation, printAccum.getCallback(),
2013                                printAccum.getUserData());
2014             return printAccum.join();
2015           },
2016           kOperationStrDunderDocstring);
2017 
2018   //----------------------------------------------------------------------------
2019   // Mapping of Operation.
2020   //----------------------------------------------------------------------------
2021   py::class_<PyOperationBase>(m, "_OperationBase")
2022       .def("__eq__",
2023            [](PyOperationBase &self, PyOperationBase &other) {
2024              return &self.getOperation() == &other.getOperation();
2025            })
2026       .def("__eq__",
2027            [](PyOperationBase &self, py::object other) { return false; })
2028       .def_property_readonly("attributes",
2029                              [](PyOperationBase &self) {
2030                                return PyOpAttributeMap(
2031                                    self.getOperation().getRef());
2032                              })
2033       .def_property_readonly("operands",
2034                              [](PyOperationBase &self) {
2035                                return PyOpOperandList(
2036                                    self.getOperation().getRef());
2037                              })
2038       .def_property_readonly("regions",
2039                              [](PyOperationBase &self) {
2040                                return PyRegionList(
2041                                    self.getOperation().getRef());
2042                              })
2043       .def_property_readonly(
2044           "results",
2045           [](PyOperationBase &self) {
2046             return PyOpResultList(self.getOperation().getRef());
2047           },
2048           "Returns the list of Operation results.")
2049       .def_property_readonly(
2050           "result",
2051           [](PyOperationBase &self) {
2052             auto &operation = self.getOperation();
2053             auto numResults = mlirOperationGetNumResults(operation);
2054             if (numResults != 1) {
2055               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2056               throw SetPyError(
2057                   PyExc_ValueError,
2058                   Twine("Cannot call .result on operation ") +
2059                       StringRef(name.data, name.length) + " which has " +
2060                       Twine(numResults) +
2061                       " results (it is only valid for operations with a "
2062                       "single result)");
2063             }
2064             return PyOpResult(operation.getRef(),
2065                               mlirOperationGetResult(operation, 0));
2066           },
2067           "Shortcut to get an op result if it has only one (throws an error "
2068           "otherwise).")
2069       .def("__iter__",
2070            [](PyOperationBase &self) {
2071              return PyRegionIterator(self.getOperation().getRef());
2072            })
2073       .def(
2074           "__str__",
2075           [](PyOperationBase &self) {
2076             return self.getAsm(/*binary=*/false,
2077                                /*largeElementsLimit=*/llvm::None,
2078                                /*enableDebugInfo=*/false,
2079                                /*prettyDebugInfo=*/false,
2080                                /*printGenericOpForm=*/false,
2081                                /*useLocalScope=*/false);
2082           },
2083           "Returns the assembly form of the operation.")
2084       .def("print", &PyOperationBase::print,
2085            // Careful: Lots of arguments must match up with print method.
2086            py::arg("file") = py::none(), py::arg("binary") = false,
2087            py::arg("large_elements_limit") = py::none(),
2088            py::arg("enable_debug_info") = false,
2089            py::arg("pretty_debug_info") = false,
2090            py::arg("print_generic_op_form") = false,
2091            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2092       .def("get_asm", &PyOperationBase::getAsm,
2093            // Careful: Lots of arguments must match up with get_asm method.
2094            py::arg("binary") = false,
2095            py::arg("large_elements_limit") = py::none(),
2096            py::arg("enable_debug_info") = false,
2097            py::arg("pretty_debug_info") = false,
2098            py::arg("print_generic_op_form") = false,
2099            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2100       .def(
2101           "verify",
2102           [](PyOperationBase &self) {
2103             return mlirOperationVerify(self.getOperation());
2104           },
2105           "Verify the operation and return true if it passes, false if it "
2106           "fails.");
2107 
2108   py::class_<PyOperation, PyOperationBase>(m, "Operation")
2109       .def_static("create", &PyOperation::create, py::arg("name"),
2110                   py::arg("results") = py::none(),
2111                   py::arg("operands") = py::none(),
2112                   py::arg("attributes") = py::none(),
2113                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2114                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2115                   kOperationCreateDocstring)
2116       .def("erase", &PyOperation::erase)
2117       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2118                              &PyOperation::getCapsule)
2119       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2120       .def_property_readonly("name",
2121                              [](PyOperation &self) {
2122                                self.checkValid();
2123                                MlirOperation operation = self.get();
2124                                MlirStringRef name = mlirIdentifierStr(
2125                                    mlirOperationGetName(operation));
2126                                return py::str(name.data, name.length);
2127                              })
2128       .def_property_readonly(
2129           "context",
2130           [](PyOperation &self) {
2131             self.checkValid();
2132             return self.getContext().getObject();
2133           },
2134           "Context that owns the Operation")
2135       .def_property_readonly("opview", &PyOperation::createOpView);
2136 
2137   auto opViewClass =
2138       py::class_<PyOpView, PyOperationBase>(m, "OpView")
2139           .def(py::init<py::object>())
2140           .def_property_readonly("operation", &PyOpView::getOperationObject)
2141           .def_property_readonly(
2142               "context",
2143               [](PyOpView &self) {
2144                 return self.getOperation().getContext().getObject();
2145               },
2146               "Context that owns the Operation")
2147           .def("__str__", [](PyOpView &self) {
2148             return py::str(self.getOperationObject());
2149           });
2150   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2151   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2152   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2153   opViewClass.attr("build_generic") = classmethod(
2154       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2155       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2156       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2157       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2158       "Builds a specific, generated OpView based on class level attributes.");
2159 
2160   //----------------------------------------------------------------------------
2161   // Mapping of PyRegion.
2162   //----------------------------------------------------------------------------
2163   py::class_<PyRegion>(m, "Region")
2164       .def_property_readonly(
2165           "blocks",
2166           [](PyRegion &self) {
2167             return PyBlockList(self.getParentOperation(), self.get());
2168           },
2169           "Returns a forward-optimized sequence of blocks.")
2170       .def(
2171           "__iter__",
2172           [](PyRegion &self) {
2173             self.checkValid();
2174             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2175             return PyBlockIterator(self.getParentOperation(), firstBlock);
2176           },
2177           "Iterates over blocks in the region.")
2178       .def("__eq__",
2179            [](PyRegion &self, PyRegion &other) {
2180              return self.get().ptr == other.get().ptr;
2181            })
2182       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2183 
2184   //----------------------------------------------------------------------------
2185   // Mapping of PyBlock.
2186   //----------------------------------------------------------------------------
2187   py::class_<PyBlock>(m, "Block")
2188       .def_property_readonly(
2189           "arguments",
2190           [](PyBlock &self) {
2191             return PyBlockArgumentList(self.getParentOperation(), self.get());
2192           },
2193           "Returns a list of block arguments.")
2194       .def_property_readonly(
2195           "operations",
2196           [](PyBlock &self) {
2197             return PyOperationList(self.getParentOperation(), self.get());
2198           },
2199           "Returns a forward-optimized sequence of operations.")
2200       .def(
2201           "__iter__",
2202           [](PyBlock &self) {
2203             self.checkValid();
2204             MlirOperation firstOperation =
2205                 mlirBlockGetFirstOperation(self.get());
2206             return PyOperationIterator(self.getParentOperation(),
2207                                        firstOperation);
2208           },
2209           "Iterates over operations in the block.")
2210       .def("__eq__",
2211            [](PyBlock &self, PyBlock &other) {
2212              return self.get().ptr == other.get().ptr;
2213            })
2214       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2215       .def(
2216           "__str__",
2217           [](PyBlock &self) {
2218             self.checkValid();
2219             PyPrintAccumulator printAccum;
2220             mlirBlockPrint(self.get(), printAccum.getCallback(),
2221                            printAccum.getUserData());
2222             return printAccum.join();
2223           },
2224           "Returns the assembly form of the block.");
2225 
2226   //----------------------------------------------------------------------------
2227   // Mapping of PyInsertionPoint.
2228   //----------------------------------------------------------------------------
2229 
2230   py::class_<PyInsertionPoint>(m, "InsertionPoint")
2231       .def(py::init<PyBlock &>(), py::arg("block"),
2232            "Inserts after the last operation but still inside the block.")
2233       .def("__enter__", &PyInsertionPoint::contextEnter)
2234       .def("__exit__", &PyInsertionPoint::contextExit)
2235       .def_property_readonly_static(
2236           "current",
2237           [](py::object & /*class*/) {
2238             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2239             if (!ip)
2240               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2241             return ip;
2242           },
2243           "Gets the InsertionPoint bound to the current thread or raises "
2244           "ValueError if none has been set")
2245       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2246            "Inserts before a referenced operation.")
2247       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2248                   py::arg("block"), "Inserts at the beginning of the block.")
2249       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2250                   py::arg("block"), "Inserts before the block terminator.")
2251       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2252            "Inserts an operation.");
2253 
2254   //----------------------------------------------------------------------------
2255   // Mapping of PyAttribute.
2256   //----------------------------------------------------------------------------
2257   py::class_<PyAttribute>(m, "Attribute")
2258       // Delegate to the PyAttribute copy constructor, which will also lifetime
2259       // extend the backing context which owns the MlirAttribute.
2260       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2261            "Casts the passed attribute to the generic Attribute")
2262       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2263                              &PyAttribute::getCapsule)
2264       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2265       .def_static(
2266           "parse",
2267           [](std::string attrSpec, DefaultingPyMlirContext context) {
2268             MlirAttribute type = mlirAttributeParseGet(
2269                 context->get(), toMlirStringRef(attrSpec));
2270             // TODO: Rework error reporting once diagnostic engine is exposed
2271             // in C API.
2272             if (mlirAttributeIsNull(type)) {
2273               throw SetPyError(PyExc_ValueError,
2274                                Twine("Unable to parse attribute: '") +
2275                                    attrSpec + "'");
2276             }
2277             return PyAttribute(context->getRef(), type);
2278           },
2279           py::arg("asm"), py::arg("context") = py::none(),
2280           "Parses an attribute from an assembly form")
2281       .def_property_readonly(
2282           "context",
2283           [](PyAttribute &self) { return self.getContext().getObject(); },
2284           "Context that owns the Attribute")
2285       .def_property_readonly("type",
2286                              [](PyAttribute &self) {
2287                                return PyType(self.getContext()->getRef(),
2288                                              mlirAttributeGetType(self));
2289                              })
2290       .def(
2291           "get_named",
2292           [](PyAttribute &self, std::string name) {
2293             return PyNamedAttribute(self, std::move(name));
2294           },
2295           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2296       .def("__eq__",
2297            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2298       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2299       .def(
2300           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2301           kDumpDocstring)
2302       .def(
2303           "__str__",
2304           [](PyAttribute &self) {
2305             PyPrintAccumulator printAccum;
2306             mlirAttributePrint(self, printAccum.getCallback(),
2307                                printAccum.getUserData());
2308             return printAccum.join();
2309           },
2310           "Returns the assembly form of the Attribute.")
2311       .def("__repr__", [](PyAttribute &self) {
2312         // Generally, assembly formats are not printed for __repr__ because
2313         // this can cause exceptionally long debug output and exceptions.
2314         // However, attribute values are generally considered useful and are
2315         // printed. This may need to be re-evaluated if debug dumps end up
2316         // being excessive.
2317         PyPrintAccumulator printAccum;
2318         printAccum.parts.append("Attribute(");
2319         mlirAttributePrint(self, printAccum.getCallback(),
2320                            printAccum.getUserData());
2321         printAccum.parts.append(")");
2322         return printAccum.join();
2323       });
2324 
2325   //----------------------------------------------------------------------------
2326   // Mapping of PyNamedAttribute
2327   //----------------------------------------------------------------------------
2328   py::class_<PyNamedAttribute>(m, "NamedAttribute")
2329       .def("__repr__",
2330            [](PyNamedAttribute &self) {
2331              PyPrintAccumulator printAccum;
2332              printAccum.parts.append("NamedAttribute(");
2333              printAccum.parts.append(
2334                  mlirIdentifierStr(self.namedAttr.name).data);
2335              printAccum.parts.append("=");
2336              mlirAttributePrint(self.namedAttr.attribute,
2337                                 printAccum.getCallback(),
2338                                 printAccum.getUserData());
2339              printAccum.parts.append(")");
2340              return printAccum.join();
2341            })
2342       .def_property_readonly(
2343           "name",
2344           [](PyNamedAttribute &self) {
2345             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2346                            mlirIdentifierStr(self.namedAttr.name).length);
2347           },
2348           "The name of the NamedAttribute binding")
2349       .def_property_readonly(
2350           "attr",
2351           [](PyNamedAttribute &self) {
2352             // TODO: When named attribute is removed/refactored, also remove
2353             // this constructor (it does an inefficient table lookup).
2354             auto contextRef = PyMlirContext::forContext(
2355                 mlirAttributeGetContext(self.namedAttr.attribute));
2356             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2357           },
2358           py::keep_alive<0, 1>(),
2359           "The underlying generic attribute of the NamedAttribute binding");
2360 
2361   //----------------------------------------------------------------------------
2362   // Mapping of PyType.
2363   //----------------------------------------------------------------------------
2364   py::class_<PyType>(m, "Type")
2365       // Delegate to the PyType copy constructor, which will also lifetime
2366       // extend the backing context which owns the MlirType.
2367       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2368            "Casts the passed type to the generic Type")
2369       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2370       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2371       .def_static(
2372           "parse",
2373           [](std::string typeSpec, DefaultingPyMlirContext context) {
2374             MlirType type =
2375                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2376             // TODO: Rework error reporting once diagnostic engine is exposed
2377             // in C API.
2378             if (mlirTypeIsNull(type)) {
2379               throw SetPyError(PyExc_ValueError,
2380                                Twine("Unable to parse type: '") + typeSpec +
2381                                    "'");
2382             }
2383             return PyType(context->getRef(), type);
2384           },
2385           py::arg("asm"), py::arg("context") = py::none(),
2386           kContextParseTypeDocstring)
2387       .def_property_readonly(
2388           "context", [](PyType &self) { return self.getContext().getObject(); },
2389           "Context that owns the Type")
2390       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2391       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2392       .def(
2393           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2394       .def(
2395           "__str__",
2396           [](PyType &self) {
2397             PyPrintAccumulator printAccum;
2398             mlirTypePrint(self, printAccum.getCallback(),
2399                           printAccum.getUserData());
2400             return printAccum.join();
2401           },
2402           "Returns the assembly form of the type.")
2403       .def("__repr__", [](PyType &self) {
2404         // Generally, assembly formats are not printed for __repr__ because
2405         // this can cause exceptionally long debug output and exceptions.
2406         // However, types are an exception as they typically have compact
2407         // assembly forms and printing them is useful.
2408         PyPrintAccumulator printAccum;
2409         printAccum.parts.append("Type(");
2410         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2411         printAccum.parts.append(")");
2412         return printAccum.join();
2413       });
2414 
2415   //----------------------------------------------------------------------------
2416   // Mapping of Value.
2417   //----------------------------------------------------------------------------
2418   py::class_<PyValue>(m, "Value")
2419       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2420       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2421       .def_property_readonly(
2422           "context",
2423           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2424           "Context in which the value lives.")
2425       .def(
2426           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2427           kDumpDocstring)
2428       .def("__eq__",
2429            [](PyValue &self, PyValue &other) {
2430              return self.get().ptr == other.get().ptr;
2431            })
2432       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2433       .def(
2434           "__str__",
2435           [](PyValue &self) {
2436             PyPrintAccumulator printAccum;
2437             printAccum.parts.append("Value(");
2438             mlirValuePrint(self.get(), printAccum.getCallback(),
2439                            printAccum.getUserData());
2440             printAccum.parts.append(")");
2441             return printAccum.join();
2442           },
2443           kValueDunderStrDocstring)
2444       .def_property_readonly("type", [](PyValue &self) {
2445         return PyType(self.getParentOperation()->getContext(),
2446                       mlirValueGetType(self.get()));
2447       });
  // Value specializations (block arguments and op results). They are
  // registered only after the `Value` class above; presumably they derive
  // from it on the Python side, and pybind11 requires a base class to be
  // registered before its subclasses. TODO confirm before reordering.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings: list, iterator and map views over block arguments,
  // blocks, operations, op attributes, operands, results and regions. Keep
  // them after the element classes bound above; the exact order is preserved
  // as-is.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings (the global debug flag wrapper).
  PyGlobalDebugFlag::bind(m);
2465 }
2466