1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include <pybind11/stl.h>
22 
23 namespace py = pybind11;
24 using namespace mlir;
25 using namespace mlir::python;
26 
27 using llvm::SmallVector;
28 using llvm::StringRef;
29 using llvm::Twine;
30 
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
34 
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

// Keyword names below must match the snake_case Python keyword arguments of
// print()/get_asm() (e.g. `use_local_scope` for the C++ `useLocalScope`).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: If set, elides elements attributes with more than
    this number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
129 
130 //------------------------------------------------------------------------------
131 // Utilities.
132 //------------------------------------------------------------------------------
133 
134 /// Helper for creating an @classmethod.
135 template <class Func, typename... Args>
136 py::object classmethod(Func f, Args... args) {
137   py::object cf = py::cpp_function(f, args...);
138   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
139 }
140 
141 static py::object
142 createCustomDialectWrapper(const std::string &dialectNamespace,
143                            py::object dialectDescriptor) {
144   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
145   if (!dialectClass) {
146     // Use the base class.
147     return py::cast(PyDialect(std::move(dialectDescriptor)));
148   }
149 
150   // Create the custom implementation.
151   return (*dialectClass)(std::move(dialectDescriptor));
152 }
153 
154 static MlirStringRef toMlirStringRef(const std::string &s) {
155   return mlirStringRefCreate(s.data(), s.size());
156 }
157 
158 /// Wrapper for the global LLVM debugging flag.
159 struct PyGlobalDebugFlag {
160   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
161 
162   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
163 
164   static void bind(py::module &m) {
165     // Debug flags.
166     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
167         .def_property_static("flag", &PyGlobalDebugFlag::get,
168                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
169   }
170 };
171 
172 //------------------------------------------------------------------------------
173 // Collections.
174 //------------------------------------------------------------------------------
175 
176 namespace {
177 
178 class PyRegionIterator {
179 public:
180   PyRegionIterator(PyOperationRef operation)
181       : operation(std::move(operation)) {}
182 
183   PyRegionIterator &dunderIter() { return *this; }
184 
185   PyRegion dunderNext() {
186     operation->checkValid();
187     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
188       throw py::stop_iteration();
189     }
190     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
191     return PyRegion(operation, region);
192   }
193 
194   static void bind(py::module &m) {
195     py::class_<PyRegionIterator>(m, "RegionIterator")
196         .def("__iter__", &PyRegionIterator::dunderIter)
197         .def("__next__", &PyRegionIterator::dunderNext);
198   }
199 
200 private:
201   PyOperationRef operation;
202   int nextIndex = 0;
203 };
204 
205 /// Regions of an op are fixed length and indexed numerically so are represented
206 /// with a sequence-like container.
207 class PyRegionList {
208 public:
209   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
210 
211   intptr_t dunderLen() {
212     operation->checkValid();
213     return mlirOperationGetNumRegions(operation->get());
214   }
215 
216   PyRegion dunderGetItem(intptr_t index) {
217     // dunderLen checks validity.
218     if (index < 0 || index >= dunderLen()) {
219       throw SetPyError(PyExc_IndexError,
220                        "attempt to access out of bounds region");
221     }
222     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
223     return PyRegion(operation, region);
224   }
225 
226   static void bind(py::module &m) {
227     py::class_<PyRegionList>(m, "RegionSequence")
228         .def("__len__", &PyRegionList::dunderLen)
229         .def("__getitem__", &PyRegionList::dunderGetItem);
230   }
231 
232 private:
233   PyOperationRef operation;
234 };
235 
236 class PyBlockIterator {
237 public:
238   PyBlockIterator(PyOperationRef operation, MlirBlock next)
239       : operation(std::move(operation)), next(next) {}
240 
241   PyBlockIterator &dunderIter() { return *this; }
242 
243   PyBlock dunderNext() {
244     operation->checkValid();
245     if (mlirBlockIsNull(next)) {
246       throw py::stop_iteration();
247     }
248 
249     PyBlock returnBlock(operation, next);
250     next = mlirBlockGetNextInRegion(next);
251     return returnBlock;
252   }
253 
254   static void bind(py::module &m) {
255     py::class_<PyBlockIterator>(m, "BlockIterator")
256         .def("__iter__", &PyBlockIterator::dunderIter)
257         .def("__next__", &PyBlockIterator::dunderNext);
258   }
259 
260 private:
261   PyOperationRef operation;
262   MlirBlock next;
263 };
264 
265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
266 /// we present them as a more full-featured list-like container but optimize
267 /// it for forward iteration. Blocks are always owned by a region.
268 class PyBlockList {
269 public:
270   PyBlockList(PyOperationRef operation, MlirRegion region)
271       : operation(std::move(operation)), region(region) {}
272 
273   PyBlockIterator dunderIter() {
274     operation->checkValid();
275     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
276   }
277 
278   intptr_t dunderLen() {
279     operation->checkValid();
280     intptr_t count = 0;
281     MlirBlock block = mlirRegionGetFirstBlock(region);
282     while (!mlirBlockIsNull(block)) {
283       count += 1;
284       block = mlirBlockGetNextInRegion(block);
285     }
286     return count;
287   }
288 
289   PyBlock dunderGetItem(intptr_t index) {
290     operation->checkValid();
291     if (index < 0) {
292       throw SetPyError(PyExc_IndexError,
293                        "attempt to access out of bounds block");
294     }
295     MlirBlock block = mlirRegionGetFirstBlock(region);
296     while (!mlirBlockIsNull(block)) {
297       if (index == 0) {
298         return PyBlock(operation, block);
299       }
300       block = mlirBlockGetNextInRegion(block);
301       index -= 1;
302     }
303     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
304   }
305 
306   PyBlock appendBlock(py::args pyArgTypes) {
307     operation->checkValid();
308     llvm::SmallVector<MlirType, 4> argTypes;
309     argTypes.reserve(pyArgTypes.size());
310     for (auto &pyArg : pyArgTypes) {
311       argTypes.push_back(pyArg.cast<PyType &>());
312     }
313 
314     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
315     mlirRegionAppendOwnedBlock(region, block);
316     return PyBlock(operation, block);
317   }
318 
319   static void bind(py::module &m) {
320     py::class_<PyBlockList>(m, "BlockList")
321         .def("__getitem__", &PyBlockList::dunderGetItem)
322         .def("__iter__", &PyBlockList::dunderIter)
323         .def("__len__", &PyBlockList::dunderLen)
324         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
325   }
326 
327 private:
328   PyOperationRef operation;
329   MlirRegion region;
330 };
331 
332 class PyOperationIterator {
333 public:
334   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
335       : parentOperation(std::move(parentOperation)), next(next) {}
336 
337   PyOperationIterator &dunderIter() { return *this; }
338 
339   py::object dunderNext() {
340     parentOperation->checkValid();
341     if (mlirOperationIsNull(next)) {
342       throw py::stop_iteration();
343     }
344 
345     PyOperationRef returnOperation =
346         PyOperation::forOperation(parentOperation->getContext(), next);
347     next = mlirOperationGetNextInBlock(next);
348     return returnOperation->createOpView();
349   }
350 
351   static void bind(py::module &m) {
352     py::class_<PyOperationIterator>(m, "OperationIterator")
353         .def("__iter__", &PyOperationIterator::dunderIter)
354         .def("__next__", &PyOperationIterator::dunderNext);
355   }
356 
357 private:
358   PyOperationRef parentOperation;
359   MlirOperation next;
360 };
361 
362 /// Operations are exposed by the C-API as a forward-only linked list. In
363 /// Python, we present them as a more full-featured list-like container but
364 /// optimize it for forward iteration. Iterable operations are always owned
365 /// by a block.
366 class PyOperationList {
367 public:
368   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
369       : parentOperation(std::move(parentOperation)), block(block) {}
370 
371   PyOperationIterator dunderIter() {
372     parentOperation->checkValid();
373     return PyOperationIterator(parentOperation,
374                                mlirBlockGetFirstOperation(block));
375   }
376 
377   intptr_t dunderLen() {
378     parentOperation->checkValid();
379     intptr_t count = 0;
380     MlirOperation childOp = mlirBlockGetFirstOperation(block);
381     while (!mlirOperationIsNull(childOp)) {
382       count += 1;
383       childOp = mlirOperationGetNextInBlock(childOp);
384     }
385     return count;
386   }
387 
388   py::object dunderGetItem(intptr_t index) {
389     parentOperation->checkValid();
390     if (index < 0) {
391       throw SetPyError(PyExc_IndexError,
392                        "attempt to access out of bounds operation");
393     }
394     MlirOperation childOp = mlirBlockGetFirstOperation(block);
395     while (!mlirOperationIsNull(childOp)) {
396       if (index == 0) {
397         return PyOperation::forOperation(parentOperation->getContext(), childOp)
398             ->createOpView();
399       }
400       childOp = mlirOperationGetNextInBlock(childOp);
401       index -= 1;
402     }
403     throw SetPyError(PyExc_IndexError,
404                      "attempt to access out of bounds operation");
405   }
406 
407   static void bind(py::module &m) {
408     py::class_<PyOperationList>(m, "OperationList")
409         .def("__getitem__", &PyOperationList::dunderGetItem)
410         .def("__iter__", &PyOperationList::dunderIter)
411         .def("__len__", &PyOperationList::dunderLen);
412   }
413 
414 private:
415   PyOperationRef parentOperation;
416   MlirBlock block;
417 };
418 
419 } // namespace
420 
421 //------------------------------------------------------------------------------
422 // PyMlirContext
423 //------------------------------------------------------------------------------
424 
425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
426   py::gil_scoped_acquire acquire;
427   auto &liveContexts = getLiveContexts();
428   liveContexts[context.ptr] = this;
429 }
430 
431 PyMlirContext::~PyMlirContext() {
432   // Note that the only public way to construct an instance is via the
433   // forContext method, which always puts the associated handle into
434   // liveContexts.
435   py::gil_scoped_acquire acquire;
436   getLiveContexts().erase(context.ptr);
437   mlirContextDestroy(context);
438 }
439 
440 py::object PyMlirContext::getCapsule() {
441   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
442 }
443 
444 py::object PyMlirContext::createFromCapsule(py::object capsule) {
445   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
446   if (mlirContextIsNull(rawContext))
447     throw py::error_already_set();
448   return forContext(rawContext).releaseObject();
449 }
450 
451 PyMlirContext *PyMlirContext::createNewContextForInit() {
452   MlirContext context = mlirContextCreate();
453   mlirRegisterAllDialects(context);
454   return new PyMlirContext(context);
455 }
456 
457 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
458   py::gil_scoped_acquire acquire;
459   auto &liveContexts = getLiveContexts();
460   auto it = liveContexts.find(context.ptr);
461   if (it == liveContexts.end()) {
462     // Create.
463     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
464     py::object pyRef = py::cast(unownedContextWrapper);
465     assert(pyRef && "cast to py::object failed");
466     liveContexts[context.ptr] = unownedContextWrapper;
467     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
468   }
469   // Use existing.
470   py::object pyRef = py::cast(it->second);
471   return PyMlirContextRef(it->second, std::move(pyRef));
472 }
473 
474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
475   static LiveContextMap liveContexts;
476   return liveContexts;
477 }
478 
479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
480 
481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
482 
483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
484 
485 pybind11::object PyMlirContext::contextEnter() {
486   return PyThreadContextEntry::pushContext(*this);
487 }
488 
489 void PyMlirContext::contextExit(pybind11::object excType,
490                                 pybind11::object excVal,
491                                 pybind11::object excTb) {
492   PyThreadContextEntry::popContext(*this);
493 }
494 
495 PyMlirContext &DefaultingPyMlirContext::resolve() {
496   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
497   if (!context) {
498     throw SetPyError(
499         PyExc_RuntimeError,
500         "An MLIR function requires a Context but none was provided in the call "
501         "or from the surrounding environment. Either pass to the function with "
502         "a 'context=' argument or establish a default using 'with Context():'");
503   }
504   return *context;
505 }
506 
507 //------------------------------------------------------------------------------
508 // PyThreadContextEntry management
509 //------------------------------------------------------------------------------
510 
511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
512   static thread_local std::vector<PyThreadContextEntry> stack;
513   return stack;
514 }
515 
516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
517   auto &stack = getStack();
518   if (stack.empty())
519     return nullptr;
520   return &stack.back();
521 }
522 
523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
524                                 py::object insertionPoint,
525                                 py::object location) {
526   auto &stack = getStack();
527   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
528                      std::move(location));
529   // If the new stack has more than one entry and the context of the new top
530   // entry matches the previous, copy the insertionPoint and location from the
531   // previous entry if missing from the new top entry.
532   if (stack.size() > 1) {
533     auto &prev = *(stack.rbegin() + 1);
534     auto &current = stack.back();
535     if (current.context.is(prev.context)) {
536       // Default non-context objects from the previous entry.
537       if (!current.insertionPoint)
538         current.insertionPoint = prev.insertionPoint;
539       if (!current.location)
540         current.location = prev.location;
541     }
542   }
543 }
544 
545 PyMlirContext *PyThreadContextEntry::getContext() {
546   if (!context)
547     return nullptr;
548   return py::cast<PyMlirContext *>(context);
549 }
550 
551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
552   if (!insertionPoint)
553     return nullptr;
554   return py::cast<PyInsertionPoint *>(insertionPoint);
555 }
556 
557 PyLocation *PyThreadContextEntry::getLocation() {
558   if (!location)
559     return nullptr;
560   return py::cast<PyLocation *>(location);
561 }
562 
563 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
564   auto *tos = getTopOfStack();
565   return tos ? tos->getContext() : nullptr;
566 }
567 
568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
569   auto *tos = getTopOfStack();
570   return tos ? tos->getInsertionPoint() : nullptr;
571 }
572 
573 PyLocation *PyThreadContextEntry::getDefaultLocation() {
574   auto *tos = getTopOfStack();
575   return tos ? tos->getLocation() : nullptr;
576 }
577 
578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
579   py::object contextObj = py::cast(context);
580   push(FrameKind::Context, /*context=*/contextObj,
581        /*insertionPoint=*/py::object(),
582        /*location=*/py::object());
583   return contextObj;
584 }
585 
586 void PyThreadContextEntry::popContext(PyMlirContext &context) {
587   auto &stack = getStack();
588   if (stack.empty())
589     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
590   auto &tos = stack.back();
591   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
592     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
593   stack.pop_back();
594 }
595 
596 py::object
597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
598   py::object contextObj =
599       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
600   py::object insertionPointObj = py::cast(insertionPoint);
601   push(FrameKind::InsertionPoint,
602        /*context=*/contextObj,
603        /*insertionPoint=*/insertionPointObj,
604        /*location=*/py::object());
605   return insertionPointObj;
606 }
607 
608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
609   auto &stack = getStack();
610   if (stack.empty())
611     throw SetPyError(PyExc_RuntimeError,
612                      "Unbalanced InsertionPoint enter/exit");
613   auto &tos = stack.back();
614   if (tos.frameKind != FrameKind::InsertionPoint &&
615       tos.getInsertionPoint() != &insertionPoint)
616     throw SetPyError(PyExc_RuntimeError,
617                      "Unbalanced InsertionPoint enter/exit");
618   stack.pop_back();
619 }
620 
621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
622   py::object contextObj = location.getContext().getObject();
623   py::object locationObj = py::cast(location);
624   push(FrameKind::Location, /*context=*/contextObj,
625        /*insertionPoint=*/py::object(),
626        /*location=*/locationObj);
627   return locationObj;
628 }
629 
630 void PyThreadContextEntry::popLocation(PyLocation &location) {
631   auto &stack = getStack();
632   if (stack.empty())
633     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
634   auto &tos = stack.back();
635   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
636     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
637   stack.pop_back();
638 }
639 
640 //------------------------------------------------------------------------------
641 // PyDialect, PyDialectDescriptor, PyDialects
642 //------------------------------------------------------------------------------
643 
644 MlirDialect PyDialects::getDialectForKey(const std::string &key,
645                                          bool attrError) {
646   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
647                                                     {key.data(), key.size()});
648   if (mlirDialectIsNull(dialect)) {
649     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
650                      Twine("Dialect '") + key + "' not found");
651   }
652   return dialect;
653 }
654 
655 //------------------------------------------------------------------------------
656 // PyLocation
657 //------------------------------------------------------------------------------
658 
659 py::object PyLocation::getCapsule() {
660   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
661 }
662 
663 PyLocation PyLocation::createFromCapsule(py::object capsule) {
664   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
665   if (mlirLocationIsNull(rawLoc))
666     throw py::error_already_set();
667   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
668                     rawLoc);
669 }
670 
671 py::object PyLocation::contextEnter() {
672   return PyThreadContextEntry::pushLocation(*this);
673 }
674 
675 void PyLocation::contextExit(py::object excType, py::object excVal,
676                              py::object excTb) {
677   PyThreadContextEntry::popLocation(*this);
678 }
679 
680 PyLocation &DefaultingPyLocation::resolve() {
681   auto *location = PyThreadContextEntry::getDefaultLocation();
682   if (!location) {
683     throw SetPyError(
684         PyExc_RuntimeError,
685         "An MLIR function requires a Location but none was provided in the "
686         "call or from the surrounding environment. Either pass to the function "
687         "with a 'loc=' argument or establish a default using 'with loc:'");
688   }
689   return *location;
690 }
691 
692 //------------------------------------------------------------------------------
693 // PyModule
694 //------------------------------------------------------------------------------
695 
696 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
697     : BaseContextObject(std::move(contextRef)), module(module) {}
698 
699 PyModule::~PyModule() {
700   py::gil_scoped_acquire acquire;
701   auto &liveModules = getContext()->liveModules;
702   assert(liveModules.count(module.ptr) == 1 &&
703          "destroying module not in live map");
704   liveModules.erase(module.ptr);
705   mlirModuleDestroy(module);
706 }
707 
708 PyModuleRef PyModule::forModule(MlirModule module) {
709   MlirContext context = mlirModuleGetContext(module);
710   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
711 
712   py::gil_scoped_acquire acquire;
713   auto &liveModules = contextRef->liveModules;
714   auto it = liveModules.find(module.ptr);
715   if (it == liveModules.end()) {
716     // Create.
717     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
718     // Note that the default return value policy on cast is automatic_reference,
719     // which does not take ownership (delete will not be called).
720     // Just be explicit.
721     py::object pyRef =
722         py::cast(unownedModule, py::return_value_policy::take_ownership);
723     unownedModule->handle = pyRef;
724     liveModules[module.ptr] =
725         std::make_pair(unownedModule->handle, unownedModule);
726     return PyModuleRef(unownedModule, std::move(pyRef));
727   }
728   // Use existing.
729   PyModule *existing = it->second.second;
730   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
731   return PyModuleRef(existing, std::move(pyRef));
732 }
733 
734 py::object PyModule::createFromCapsule(py::object capsule) {
735   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
736   if (mlirModuleIsNull(rawModule))
737     throw py::error_already_set();
738   return forModule(rawModule).releaseObject();
739 }
740 
741 py::object PyModule::getCapsule() {
742   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
743 }
744 
745 //------------------------------------------------------------------------------
746 // PyOperation
747 //------------------------------------------------------------------------------
748 
749 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
750     : BaseContextObject(std::move(contextRef)), operation(operation) {}
751 
752 PyOperation::~PyOperation() {
753   // If the operation has already been invalidated there is nothing to do.
754   if (!valid)
755     return;
756   auto &liveOperations = getContext()->liveOperations;
757   assert(liveOperations.count(operation.ptr) == 1 &&
758          "destroying operation not in live map");
759   liveOperations.erase(operation.ptr);
760   if (!isAttached()) {
761     mlirOperationDestroy(operation);
762   }
763 }
764 
765 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
766                                            MlirOperation operation,
767                                            py::object parentKeepAlive) {
768   auto &liveOperations = contextRef->liveOperations;
769   // Create.
770   PyOperation *unownedOperation =
771       new PyOperation(std::move(contextRef), operation);
772   // Note that the default return value policy on cast is automatic_reference,
773   // which does not take ownership (delete will not be called).
774   // Just be explicit.
775   py::object pyRef =
776       py::cast(unownedOperation, py::return_value_policy::take_ownership);
777   unownedOperation->handle = pyRef;
778   if (parentKeepAlive) {
779     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
780   }
781   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
782   return PyOperationRef(unownedOperation, std::move(pyRef));
783 }
784 
785 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
786                                          MlirOperation operation,
787                                          py::object parentKeepAlive) {
788   auto &liveOperations = contextRef->liveOperations;
789   auto it = liveOperations.find(operation.ptr);
790   if (it == liveOperations.end()) {
791     // Create.
792     return createInstance(std::move(contextRef), operation,
793                           std::move(parentKeepAlive));
794   }
795   // Use existing.
796   PyOperation *existing = it->second.second;
797   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
798   return PyOperationRef(existing, std::move(pyRef));
799 }
800 
801 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
802                                            MlirOperation operation,
803                                            py::object parentKeepAlive) {
804   auto &liveOperations = contextRef->liveOperations;
805   assert(liveOperations.count(operation.ptr) == 0 &&
806          "cannot create detached operation that already exists");
807   (void)liveOperations;
808 
809   PyOperationRef created = createInstance(std::move(contextRef), operation,
810                                           std::move(parentKeepAlive));
811   created->attached = false;
812   return created;
813 }
814 
815 void PyOperation::checkValid() const {
816   if (!valid) {
817     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
818   }
819 }
820 
821 void PyOperationBase::print(py::object fileObject, bool binary,
822                             llvm::Optional<int64_t> largeElementsLimit,
823                             bool enableDebugInfo, bool prettyDebugInfo,
824                             bool printGenericOpForm, bool useLocalScope) {
825   PyOperation &operation = getOperation();
826   operation.checkValid();
827   if (fileObject.is_none())
828     fileObject = py::module::import("sys").attr("stdout");
829 
830   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
831     fileObject.attr("write")("// Verification failed, printing generic form\n");
832     printGenericOpForm = true;
833   }
834 
835   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
836   if (largeElementsLimit)
837     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
838   if (enableDebugInfo)
839     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
840   if (printGenericOpForm)
841     mlirOpPrintingFlagsPrintGenericOpForm(flags);
842 
843   PyFileAccumulator accum(fileObject, binary);
844   py::gil_scoped_release();
845   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
846                               accum.getUserData());
847   mlirOpPrintingFlagsDestroy(flags);
848 }
849 
850 py::object PyOperationBase::getAsm(bool binary,
851                                    llvm::Optional<int64_t> largeElementsLimit,
852                                    bool enableDebugInfo, bool prettyDebugInfo,
853                                    bool printGenericOpForm,
854                                    bool useLocalScope) {
855   py::object fileObject;
856   if (binary) {
857     fileObject = py::module::import("io").attr("BytesIO")();
858   } else {
859     fileObject = py::module::import("io").attr("StringIO")();
860   }
861   print(fileObject, /*binary=*/binary,
862         /*largeElementsLimit=*/largeElementsLimit,
863         /*enableDebugInfo=*/enableDebugInfo,
864         /*prettyDebugInfo=*/prettyDebugInfo,
865         /*printGenericOpForm=*/printGenericOpForm,
866         /*useLocalScope=*/useLocalScope);
867 
868   return fileObject.attr("getvalue")();
869 }
870 
871 PyOperationRef PyOperation::getParentOperation() {
872   checkValid();
873   if (!isAttached())
874     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
875   MlirOperation operation = mlirOperationGetParentOperation(get());
876   if (mlirOperationIsNull(operation))
877     throw SetPyError(PyExc_ValueError, "Operation has no parent.");
878   return PyOperation::forOperation(getContext(), operation);
879 }
880 
881 PyBlock PyOperation::getBlock() {
882   checkValid();
883   PyOperationRef parentOperation = getParentOperation();
884   MlirBlock block = mlirOperationGetBlock(get());
885   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
886   return PyBlock{std::move(parentOperation), block};
887 }
888 
889 py::object PyOperation::getCapsule() {
890   checkValid();
891   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
892 }
893 
894 py::object PyOperation::createFromCapsule(py::object capsule) {
895   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
896   if (mlirOperationIsNull(rawOperation))
897     throw py::error_already_set();
898   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
899   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
900       .releaseObject();
901 }
902 
903 py::object PyOperation::create(
904     std::string name, llvm::Optional<std::vector<PyType *>> results,
905     llvm::Optional<std::vector<PyValue *>> operands,
906     llvm::Optional<py::dict> attributes,
907     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
908     DefaultingPyLocation location, py::object maybeIp) {
909   llvm::SmallVector<MlirValue, 4> mlirOperands;
910   llvm::SmallVector<MlirType, 4> mlirResults;
911   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
912   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
913 
914   // General parameter validation.
915   if (regions < 0)
916     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
917 
918   // Unpack/validate operands.
919   if (operands) {
920     mlirOperands.reserve(operands->size());
921     for (PyValue *operand : *operands) {
922       if (!operand)
923         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
924       mlirOperands.push_back(operand->get());
925     }
926   }
927 
928   // Unpack/validate results.
929   if (results) {
930     mlirResults.reserve(results->size());
931     for (PyType *result : *results) {
932       // TODO: Verify result type originate from the same context.
933       if (!result)
934         throw SetPyError(PyExc_ValueError, "result type cannot be None");
935       mlirResults.push_back(*result);
936     }
937   }
938   // Unpack/validate attributes.
939   if (attributes) {
940     mlirAttributes.reserve(attributes->size());
941     for (auto &it : *attributes) {
942       std::string key;
943       try {
944         key = it.first.cast<std::string>();
945       } catch (py::cast_error &err) {
946         std::string msg = "Invalid attribute key (not a string) when "
947                           "attempting to create the operation \"" +
948                           name + "\" (" + err.what() + ")";
949         throw py::cast_error(msg);
950       }
951       try {
952         auto &attribute = it.second.cast<PyAttribute &>();
953         // TODO: Verify attribute originates from the same context.
954         mlirAttributes.emplace_back(std::move(key), attribute);
955       } catch (py::reference_cast_error &) {
956         // This exception seems thrown when the value is "None".
957         std::string msg =
958             "Found an invalid (`None`?) attribute value for the key \"" + key +
959             "\" when attempting to create the operation \"" + name + "\"";
960         throw py::cast_error(msg);
961       } catch (py::cast_error &err) {
962         std::string msg = "Invalid attribute value for the key \"" + key +
963                           "\" when attempting to create the operation \"" +
964                           name + "\" (" + err.what() + ")";
965         throw py::cast_error(msg);
966       }
967     }
968   }
969   // Unpack/validate successors.
970   if (successors) {
971     llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
972     mlirSuccessors.reserve(successors->size());
973     for (auto *successor : *successors) {
974       // TODO: Verify successor originate from the same context.
975       if (!successor)
976         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
977       mlirSuccessors.push_back(successor->get());
978     }
979   }
980 
981   // Apply unpacked/validated to the operation state. Beyond this
982   // point, exceptions cannot be thrown or else the state will leak.
983   MlirOperationState state =
984       mlirOperationStateGet(toMlirStringRef(name), location);
985   if (!mlirOperands.empty())
986     mlirOperationStateAddOperands(&state, mlirOperands.size(),
987                                   mlirOperands.data());
988   if (!mlirResults.empty())
989     mlirOperationStateAddResults(&state, mlirResults.size(),
990                                  mlirResults.data());
991   if (!mlirAttributes.empty()) {
992     // Note that the attribute names directly reference bytes in
993     // mlirAttributes, so that vector must not be changed from here
994     // on.
995     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
996     mlirNamedAttributes.reserve(mlirAttributes.size());
997     for (auto &it : mlirAttributes)
998       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
999           mlirIdentifierGet(mlirAttributeGetContext(it.second),
1000                             toMlirStringRef(it.first)),
1001           it.second));
1002     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1003                                     mlirNamedAttributes.data());
1004   }
1005   if (!mlirSuccessors.empty())
1006     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1007                                     mlirSuccessors.data());
1008   if (regions) {
1009     llvm::SmallVector<MlirRegion, 4> mlirRegions;
1010     mlirRegions.resize(regions);
1011     for (int i = 0; i < regions; ++i)
1012       mlirRegions[i] = mlirRegionCreate();
1013     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1014                                       mlirRegions.data());
1015   }
1016 
1017   // Construct the operation.
1018   MlirOperation operation = mlirOperationCreate(&state);
1019   PyOperationRef created =
1020       PyOperation::createDetached(location->getContext(), operation);
1021 
1022   // InsertPoint active?
1023   if (!maybeIp.is(py::cast(false))) {
1024     PyInsertionPoint *ip;
1025     if (maybeIp.is_none()) {
1026       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1027     } else {
1028       ip = py::cast<PyInsertionPoint *>(maybeIp);
1029     }
1030     if (ip)
1031       ip->insert(*created.get());
1032   }
1033 
1034   return created->createOpView();
1035 }
1036 
1037 py::object PyOperation::createOpView() {
1038   checkValid();
1039   MlirIdentifier ident = mlirOperationGetName(get());
1040   MlirStringRef identStr = mlirIdentifierStr(ident);
1041   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1042       StringRef(identStr.data, identStr.length));
1043   if (opViewClass)
1044     return (*opViewClass)(getRef().getObject());
1045   return py::cast(PyOpView(getRef().getObject()));
1046 }
1047 
1048 void PyOperation::erase() {
1049   checkValid();
1050   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1051   // Python reference to a child operation is live. All children should also
1052   // have their `valid` bit set to false.
1053   auto &liveOperations = getContext()->liveOperations;
1054   if (liveOperations.count(operation.ptr))
1055     liveOperations.erase(operation.ptr);
1056   mlirOperationDestroy(operation);
1057   valid = false;
1058 }
1059 
1060 //------------------------------------------------------------------------------
1061 // PyOpView
1062 //------------------------------------------------------------------------------
1063 
1064 py::object
1065 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1066                        py::list operandList,
1067                        llvm::Optional<py::dict> attributes,
1068                        llvm::Optional<std::vector<PyBlock *>> successors,
1069                        llvm::Optional<int> regions,
1070                        DefaultingPyLocation location, py::object maybeIp) {
1071   PyMlirContextRef context = location->getContext();
1072   // Class level operation construction metadata.
1073   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1074   // Operand and result segment specs are either none, which does no
1075   // variadic unpacking, or a list of ints with segment sizes, where each
1076   // element is either a positive number (typically 1 for a scalar) or -1 to
1077   // indicate that it is derived from the length of the same-indexed operand
1078   // or result (implying that it is a list at that position).
1079   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1080   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1081 
1082   std::vector<uint32_t> operandSegmentLengths;
1083   std::vector<uint32_t> resultSegmentLengths;
1084 
1085   // Validate/determine region count.
1086   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1087   int opMinRegionCount = std::get<0>(opRegionSpec);
1088   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1089   if (!regions) {
1090     regions = opMinRegionCount;
1091   }
1092   if (*regions < opMinRegionCount) {
1093     throw py::value_error(
1094         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1095          llvm::Twine(opMinRegionCount) +
1096          " regions but was built with regions=" + llvm::Twine(*regions))
1097             .str());
1098   }
1099   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1100     throw py::value_error(
1101         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1102          llvm::Twine(opMinRegionCount) +
1103          " regions but was built with regions=" + llvm::Twine(*regions))
1104             .str());
1105   }
1106 
1107   // Unpack results.
1108   std::vector<PyType *> resultTypes;
1109   resultTypes.reserve(resultTypeList.size());
1110   if (resultSegmentSpecObj.is_none()) {
1111     // Non-variadic result unpacking.
1112     for (auto it : llvm::enumerate(resultTypeList)) {
1113       try {
1114         resultTypes.push_back(py::cast<PyType *>(it.value()));
1115         if (!resultTypes.back())
1116           throw py::cast_error();
1117       } catch (py::cast_error &err) {
1118         throw py::value_error((llvm::Twine("Result ") +
1119                                llvm::Twine(it.index()) + " of operation \"" +
1120                                name + "\" must be a Type (" + err.what() + ")")
1121                                   .str());
1122       }
1123     }
1124   } else {
1125     // Sized result unpacking.
1126     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1127     if (resultSegmentSpec.size() != resultTypeList.size()) {
1128       throw py::value_error((llvm::Twine("Operation \"") + name +
1129                              "\" requires " +
1130                              llvm::Twine(resultSegmentSpec.size()) +
1131                              "result segments but was provided " +
1132                              llvm::Twine(resultTypeList.size()))
1133                                 .str());
1134     }
1135     resultSegmentLengths.reserve(resultTypeList.size());
1136     for (auto it :
1137          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1138       int segmentSpec = std::get<1>(it.value());
1139       if (segmentSpec == 1 || segmentSpec == 0) {
1140         // Unpack unary element.
1141         try {
1142           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1143           if (resultType) {
1144             resultTypes.push_back(resultType);
1145             resultSegmentLengths.push_back(1);
1146           } else if (segmentSpec == 0) {
1147             // Allowed to be optional.
1148             resultSegmentLengths.push_back(0);
1149           } else {
1150             throw py::cast_error("was None and result is not optional");
1151           }
1152         } catch (py::cast_error &err) {
1153           throw py::value_error((llvm::Twine("Result ") +
1154                                  llvm::Twine(it.index()) + " of operation \"" +
1155                                  name + "\" must be a Type (" + err.what() +
1156                                  ")")
1157                                     .str());
1158         }
1159       } else if (segmentSpec == -1) {
1160         // Unpack sequence by appending.
1161         try {
1162           if (std::get<0>(it.value()).is_none()) {
1163             // Treat it as an empty list.
1164             resultSegmentLengths.push_back(0);
1165           } else {
1166             // Unpack the list.
1167             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1168             for (py::object segmentItem : segment) {
1169               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1170               if (!resultTypes.back()) {
1171                 throw py::cast_error("contained a None item");
1172               }
1173             }
1174             resultSegmentLengths.push_back(segment.size());
1175           }
1176         } catch (std::exception &err) {
1177           // NOTE: Sloppy to be using a catch-all here, but there are at least
1178           // three different unrelated exceptions that can be thrown in the
1179           // above "casts". Just keep the scope above small and catch them all.
1180           throw py::value_error((llvm::Twine("Result ") +
1181                                  llvm::Twine(it.index()) + " of operation \"" +
1182                                  name + "\" must be a Sequence of Types (" +
1183                                  err.what() + ")")
1184                                     .str());
1185         }
1186       } else {
1187         throw py::value_error("Unexpected segment spec");
1188       }
1189     }
1190   }
1191 
1192   // Unpack operands.
1193   std::vector<PyValue *> operands;
1194   operands.reserve(operands.size());
1195   if (operandSegmentSpecObj.is_none()) {
1196     // Non-sized operand unpacking.
1197     for (auto it : llvm::enumerate(operandList)) {
1198       try {
1199         operands.push_back(py::cast<PyValue *>(it.value()));
1200         if (!operands.back())
1201           throw py::cast_error();
1202       } catch (py::cast_error &err) {
1203         throw py::value_error((llvm::Twine("Operand ") +
1204                                llvm::Twine(it.index()) + " of operation \"" +
1205                                name + "\" must be a Value (" + err.what() + ")")
1206                                   .str());
1207       }
1208     }
1209   } else {
1210     // Sized operand unpacking.
1211     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1212     if (operandSegmentSpec.size() != operandList.size()) {
1213       throw py::value_error((llvm::Twine("Operation \"") + name +
1214                              "\" requires " +
1215                              llvm::Twine(operandSegmentSpec.size()) +
1216                              "operand segments but was provided " +
1217                              llvm::Twine(operandList.size()))
1218                                 .str());
1219     }
1220     operandSegmentLengths.reserve(operandList.size());
1221     for (auto it :
1222          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1223       int segmentSpec = std::get<1>(it.value());
1224       if (segmentSpec == 1 || segmentSpec == 0) {
1225         // Unpack unary element.
1226         try {
1227           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1228           if (operandValue) {
1229             operands.push_back(operandValue);
1230             operandSegmentLengths.push_back(1);
1231           } else if (segmentSpec == 0) {
1232             // Allowed to be optional.
1233             operandSegmentLengths.push_back(0);
1234           } else {
1235             throw py::cast_error("was None and operand is not optional");
1236           }
1237         } catch (py::cast_error &err) {
1238           throw py::value_error((llvm::Twine("Operand ") +
1239                                  llvm::Twine(it.index()) + " of operation \"" +
1240                                  name + "\" must be a Value (" + err.what() +
1241                                  ")")
1242                                     .str());
1243         }
1244       } else if (segmentSpec == -1) {
1245         // Unpack sequence by appending.
1246         try {
1247           if (std::get<0>(it.value()).is_none()) {
1248             // Treat it as an empty list.
1249             operandSegmentLengths.push_back(0);
1250           } else {
1251             // Unpack the list.
1252             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1253             for (py::object segmentItem : segment) {
1254               operands.push_back(py::cast<PyValue *>(segmentItem));
1255               if (!operands.back()) {
1256                 throw py::cast_error("contained a None item");
1257               }
1258             }
1259             operandSegmentLengths.push_back(segment.size());
1260           }
1261         } catch (std::exception &err) {
1262           // NOTE: Sloppy to be using a catch-all here, but there are at least
1263           // three different unrelated exceptions that can be thrown in the
1264           // above "casts". Just keep the scope above small and catch them all.
1265           throw py::value_error((llvm::Twine("Operand ") +
1266                                  llvm::Twine(it.index()) + " of operation \"" +
1267                                  name + "\" must be a Sequence of Values (" +
1268                                  err.what() + ")")
1269                                     .str());
1270         }
1271       } else {
1272         throw py::value_error("Unexpected segment spec");
1273       }
1274     }
1275   }
1276 
1277   // Merge operand/result segment lengths into attributes if needed.
1278   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1279     // Dup.
1280     if (attributes) {
1281       attributes = py::dict(*attributes);
1282     } else {
1283       attributes = py::dict();
1284     }
1285     if (attributes->contains("result_segment_sizes") ||
1286         attributes->contains("operand_segment_sizes")) {
1287       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1288                             "'operand_segment_sizes' attribute is unsupported. "
1289                             "Use Operation.create for such low-level access.");
1290     }
1291 
1292     // Add result_segment_sizes attribute.
1293     if (!resultSegmentLengths.empty()) {
1294       int64_t size = resultSegmentLengths.size();
1295       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1296           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1297           resultSegmentLengths.size(), resultSegmentLengths.data());
1298       (*attributes)["result_segment_sizes"] =
1299           PyAttribute(context, segmentLengthAttr);
1300     }
1301 
1302     // Add operand_segment_sizes attribute.
1303     if (!operandSegmentLengths.empty()) {
1304       int64_t size = operandSegmentLengths.size();
1305       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1306           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1307           operandSegmentLengths.size(), operandSegmentLengths.data());
1308       (*attributes)["operand_segment_sizes"] =
1309           PyAttribute(context, segmentLengthAttr);
1310     }
1311   }
1312 
1313   // Delegate to create.
1314   return PyOperation::create(std::move(name),
1315                              /*results=*/std::move(resultTypes),
1316                              /*operands=*/std::move(operands),
1317                              /*attributes=*/std::move(attributes),
1318                              /*successors=*/std::move(successors),
1319                              /*regions=*/*regions, location, maybeIp);
1320 }
1321 
/// Constructs an OpView over any PyOperationBase-derived Python object.
///
/// The stored `operationObject` member is the Python object obtained from the
/// underlying PyOperation's own reference, not necessarily the argument. So
/// wrapping another OpView yields a view over the same Operation object.
///
/// NOTE(review): the second initializer reads the `operation` member. This
/// relies on `operation` being declared before `operationObject` in the class
/// definition (not visible here); confirm before reordering members. Inside
/// the first initializer, `operationObject` names the constructor parameter,
/// which shadows the member.
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1327 
1328 py::object PyOpView::createRawSubclass(py::object userClass) {
1329   // This is... a little gross. The typical pattern is to have a pure python
1330   // class that extends OpView like:
1331   //   class AddFOp(_cext.ir.OpView):
1332   //     def __init__(self, loc, lhs, rhs):
1333   //       operation = loc.context.create_operation(
1334   //           "addf", lhs, rhs, results=[lhs.type])
1335   //       super().__init__(operation)
1336   //
1337   // I.e. The goal of the user facing type is to provide a nice constructor
1338   // that has complete freedom for the op under construction. This is at odds
1339   // with our other desire to sometimes create this object by just passing an
1340   // operation (to initialize the base class). We could do *arg and **kwargs
1341   // munging to try to make it work, but instead, we synthesize a new class
1342   // on the fly which extends this user class (AddFOp in this example) and
1343   // *give it* the base class's __init__ method, thus bypassing the
1344   // intermediate subclass's __init__ method entirely. While slightly,
1345   // underhanded, this is safe/legal because the type hierarchy has not changed
1346   // (we just added a new leaf) and we aren't mucking around with __new__.
1347   // Typically, this new class will be stored on the original as "_Raw" and will
1348   // be used for casts and other things that need a variant of the class that
1349   // is initialized purely from an operation.
1350   py::object parentMetaclass =
1351       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1352   py::dict attributes;
1353   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1354   // now.
1355   //   auto opViewType = py::type::of<PyOpView>();
1356   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1357   attributes["__init__"] = opViewType.attr("__init__");
1358   py::str origName = userClass.attr("__name__");
1359   py::str newName = py::str("_") + origName;
1360   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1361 }
1362 
1363 //------------------------------------------------------------------------------
1364 // PyInsertionPoint.
1365 //------------------------------------------------------------------------------
1366 
/// Creates an insertion point that appends to the end of `block`. No
/// reference operation is set, so insert() places operations before null.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}

/// Creates an insertion point just before `beforeOperationBase`, within the
/// block that contains it.
///
/// Raises (via getBlock()) if that operation is detached or has no parent.
///
/// NOTE(review): `block` is initialized from `refOperation`. This relies on
/// `refOperation` being declared before `block` in the class definition (not
/// visible here); confirm before reordering members.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1372 
1373 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1374   PyOperation &operation = operationBase.getOperation();
1375   if (operation.isAttached())
1376     throw SetPyError(PyExc_ValueError,
1377                      "Attempt to insert operation that is already attached");
1378   block.getParentOperation()->checkValid();
1379   MlirOperation beforeOp = {nullptr};
1380   if (refOperation) {
1381     // Insert before operation.
1382     (*refOperation)->checkValid();
1383     beforeOp = (*refOperation)->get();
1384   } else {
1385     // Insert at end (before null) is only valid if the block does not
1386     // already end in a known terminator (violating this will cause assertion
1387     // failures later).
1388     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1389       throw py::index_error("Cannot insert operation at the end of a block "
1390                             "that already has a terminator. Did you mean to "
1391                             "use 'InsertionPoint.at_block_terminator(block)' "
1392                             "versus 'InsertionPoint(block)'?");
1393     }
1394   }
1395   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1396   operation.setAttached();
1397 }
1398 
1399 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1400   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1401   if (mlirOperationIsNull(firstOp)) {
1402     // Just insert at end.
1403     return PyInsertionPoint(block);
1404   }
1405 
1406   // Insert before first op.
1407   PyOperationRef firstOpRef = PyOperation::forOperation(
1408       block.getParentOperation()->getContext(), firstOp);
1409   return PyInsertionPoint{block, std::move(firstOpRef)};
1410 }
1411 
1412 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1413   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1414   if (mlirOperationIsNull(terminator))
1415     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1416   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1417       block.getParentOperation()->getContext(), terminator);
1418   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1419 }
1420 
1421 py::object PyInsertionPoint::contextEnter() {
1422   return PyThreadContextEntry::pushInsertionPoint(*this);
1423 }
1424 
1425 void PyInsertionPoint::contextExit(pybind11::object excType,
1426                                    pybind11::object excVal,
1427                                    pybind11::object excTb) {
1428   PyThreadContextEntry::popInsertionPoint(*this);
1429 }
1430 
1431 //------------------------------------------------------------------------------
1432 // PyAttribute.
1433 //------------------------------------------------------------------------------
1434 
1435 bool PyAttribute::operator==(const PyAttribute &other) {
1436   return mlirAttributeEqual(attr, other.attr);
1437 }
1438 
1439 py::object PyAttribute::getCapsule() {
1440   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1441 }
1442 
1443 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1444   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1445   if (mlirAttributeIsNull(rawAttr))
1446     throw py::error_already_set();
1447   return PyAttribute(
1448       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1449 }
1450 
1451 //------------------------------------------------------------------------------
1452 // PyNamedAttribute.
1453 //------------------------------------------------------------------------------
1454 
1455 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1456     : ownedName(new std::string(std::move(ownedName))) {
1457   namedAttr = mlirNamedAttributeGet(
1458       mlirIdentifierGet(mlirAttributeGetContext(attr),
1459                         toMlirStringRef(*this->ownedName)),
1460       attr);
1461 }
1462 
1463 //------------------------------------------------------------------------------
1464 // PyType.
1465 //------------------------------------------------------------------------------
1466 
1467 bool PyType::operator==(const PyType &other) {
1468   return mlirTypeEqual(type, other.type);
1469 }
1470 
1471 py::object PyType::getCapsule() {
1472   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1473 }
1474 
1475 PyType PyType::createFromCapsule(py::object capsule) {
1476   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1477   if (mlirTypeIsNull(rawType))
1478     throw py::error_already_set();
1479   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1480                 rawType);
1481 }
1482 
1483 //------------------------------------------------------------------------------
1484 // PyValue and subclases.
1485 //------------------------------------------------------------------------------
1486 
1487 pybind11::object PyValue::getCapsule() {
1488   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1489 }
1490 
1491 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1492   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1493   if (mlirValueIsNull(value))
1494     throw py::error_already_set();
1495   MlirOperation owner;
1496   if (mlirValueIsAOpResult(value))
1497     owner = mlirOpResultGetOwner(value);
1498   if (mlirValueIsABlockArgument(value))
1499     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1500   if (mlirOperationIsNull(owner))
1501     throw py::error_already_set();
1502   MlirContext ctx = mlirOperationGetContext(owner);
1503   PyOperationRef ownerRef =
1504       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1505   return PyValue(ownerRef, value);
1506 }
1507 
1508 namespace {
1509 /// CRTP base class for Python MLIR values that subclass Value and should be
1510 /// castable from it. The value hierarchy is one level deep and is not supposed
1511 /// to accommodate other levels unless core MLIR changes.
1512 template <typename DerivedTy>
1513 class PyConcreteValue : public PyValue {
1514 public:
1515   // Derived classes must define statics for:
1516   //   IsAFunctionTy isaFunction
1517   //   const char *pyClassName
1518   // and redefine bindDerived.
1519   using ClassTy = py::class_<DerivedTy, PyValue>;
1520   using IsAFunctionTy = bool (*)(MlirValue);
1521 
1522   PyConcreteValue() = default;
1523   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1524       : PyValue(operationRef, value) {}
1525   PyConcreteValue(PyValue &orig)
1526       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1527 
1528   /// Attempts to cast the original value to the derived type and throws on
1529   /// type mismatches.
1530   static MlirValue castFrom(PyValue &orig) {
1531     if (!DerivedTy::isaFunction(orig.get())) {
1532       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1533       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1534                                              DerivedTy::pyClassName +
1535                                              " (from " + origRepr + ")");
1536     }
1537     return orig.get();
1538   }
1539 
1540   /// Binds the Python module objects to functions of this class.
1541   static void bind(py::module &m) {
1542     auto cls = ClassTy(m, DerivedTy::pyClassName);
1543     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1544     DerivedTy::bindDerived(cls);
1545   }
1546 
1547   /// Implemented by derived classes to add methods to the Python subclass.
1548   static void bindDerived(ClassTy &m) {}
1549 };
1550 
1551 /// Python wrapper for MlirBlockArgument.
1552 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1553 public:
1554   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1555   static constexpr const char *pyClassName = "BlockArgument";
1556   using PyConcreteValue::PyConcreteValue;
1557 
1558   static void bindDerived(ClassTy &c) {
1559     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1560       return PyBlock(self.getParentOperation(),
1561                      mlirBlockArgumentGetOwner(self.get()));
1562     });
1563     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1564       return mlirBlockArgumentGetArgNumber(self.get());
1565     });
1566     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1567       return mlirBlockArgumentSetType(self.get(), type);
1568     });
1569   }
1570 };
1571 
1572 /// Python wrapper for MlirOpResult.
1573 class PyOpResult : public PyConcreteValue<PyOpResult> {
1574 public:
1575   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1576   static constexpr const char *pyClassName = "OpResult";
1577   using PyConcreteValue::PyConcreteValue;
1578 
1579   static void bindDerived(ClassTy &c) {
1580     c.def_property_readonly("owner", [](PyOpResult &self) {
1581       assert(
1582           mlirOperationEqual(self.getParentOperation()->get(),
1583                              mlirOpResultGetOwner(self.get())) &&
1584           "expected the owner of the value in Python to match that in the IR");
1585       return self.getParentOperation().getObject();
1586     });
1587     c.def_property_readonly("result_number", [](PyOpResult &self) {
1588       return mlirOpResultGetResultNumber(self.get());
1589     });
1590   }
1591 };
1592 
1593 /// A list of block arguments. Internally, these are stored as consecutive
1594 /// elements, random access is cheap. The argument list is associated with the
1595 /// operation that contains the block (detached blocks are not allowed in
1596 /// Python bindings) and extends its lifetime.
1597 class PyBlockArgumentList {
1598 public:
1599   PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
1600       : operation(std::move(operation)), block(block) {}
1601 
1602   /// Returns the length of the block argument list.
1603   intptr_t dunderLen() {
1604     operation->checkValid();
1605     return mlirBlockGetNumArguments(block);
1606   }
1607 
1608   /// Returns `index`-th element of the block argument list.
1609   PyBlockArgument dunderGetItem(intptr_t index) {
1610     if (index < 0 || index >= dunderLen()) {
1611       throw SetPyError(PyExc_IndexError,
1612                        "attempt to access out of bounds region");
1613     }
1614     PyValue value(operation, mlirBlockGetArgument(block, index));
1615     return PyBlockArgument(value);
1616   }
1617 
1618   /// Defines a Python class in the bindings.
1619   static void bind(py::module &m) {
1620     py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
1621         .def("__len__", &PyBlockArgumentList::dunderLen)
1622         .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
1623   }
1624 
1625 private:
1626   PyOperationRef operation;
1627   MlirBlock block;
1628 };
1629 
1630 /// A list of operation operands. Internally, these are stored as consecutive
1631 /// elements, random access is cheap. The result list is associated with the
1632 /// operation whose results these are, and extends the lifetime of this
1633 /// operation.
1634 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1635 public:
1636   static constexpr const char *pyClassName = "OpOperandList";
1637 
1638   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1639                   intptr_t length = -1, intptr_t step = 1)
1640       : Sliceable(startIndex,
1641                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1642                                : length,
1643                   step),
1644         operation(operation) {}
1645 
1646   intptr_t getNumElements() {
1647     operation->checkValid();
1648     return mlirOperationGetNumOperands(operation->get());
1649   }
1650 
1651   PyValue getElement(intptr_t pos) {
1652     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1653     MlirOperation owner;
1654     if (mlirValueIsAOpResult(operand))
1655       owner = mlirOpResultGetOwner(operand);
1656     else if (mlirValueIsABlockArgument(operand))
1657       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1658     else
1659       assert(false && "Value must be an block arg or op result.");
1660     PyOperationRef pyOwner =
1661         PyOperation::forOperation(operation->getContext(), owner);
1662     return PyValue(pyOwner, operand);
1663   }
1664 
1665   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1666     return PyOpOperandList(operation, startIndex, length, step);
1667   }
1668 
1669   void dunderSetItem(intptr_t index, PyValue value) {
1670     index = wrapIndex(index);
1671     mlirOperationSetOperand(operation->get(), index, value.get());
1672   }
1673 
1674   static void bindDerived(ClassTy &c) {
1675     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1676   }
1677 
1678 private:
1679   PyOperationRef operation;
1680 };
1681 
1682 /// A list of operation results. Internally, these are stored as consecutive
1683 /// elements, random access is cheap. The result list is associated with the
1684 /// operation whose results these are, and extends the lifetime of this
1685 /// operation.
1686 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1687 public:
1688   static constexpr const char *pyClassName = "OpResultList";
1689 
1690   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1691                  intptr_t length = -1, intptr_t step = 1)
1692       : Sliceable(startIndex,
1693                   length == -1 ? mlirOperationGetNumResults(operation->get())
1694                                : length,
1695                   step),
1696         operation(operation) {}
1697 
1698   intptr_t getNumElements() {
1699     operation->checkValid();
1700     return mlirOperationGetNumResults(operation->get());
1701   }
1702 
1703   PyOpResult getElement(intptr_t index) {
1704     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1705     return PyOpResult(value);
1706   }
1707 
1708   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1709     return PyOpResultList(operation, startIndex, length, step);
1710   }
1711 
1712 private:
1713   PyOperationRef operation;
1714 };
1715 
1716 /// A list of operation attributes. Can be indexed by name, producing
1717 /// attributes, or by index, producing named attributes.
1718 class PyOpAttributeMap {
1719 public:
1720   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1721 
1722   PyAttribute dunderGetItemNamed(const std::string &name) {
1723     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1724                                                          toMlirStringRef(name));
1725     if (mlirAttributeIsNull(attr)) {
1726       throw SetPyError(PyExc_KeyError,
1727                        "attempt to access a non-existent attribute");
1728     }
1729     return PyAttribute(operation->getContext(), attr);
1730   }
1731 
1732   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1733     if (index < 0 || index >= dunderLen()) {
1734       throw SetPyError(PyExc_IndexError,
1735                        "attempt to access out of bounds attribute");
1736     }
1737     MlirNamedAttribute namedAttr =
1738         mlirOperationGetAttribute(operation->get(), index);
1739     return PyNamedAttribute(
1740         namedAttr.attribute,
1741         std::string(mlirIdentifierStr(namedAttr.name).data));
1742   }
1743 
1744   void dunderSetItem(const std::string &name, PyAttribute attr) {
1745     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1746                                     attr);
1747   }
1748 
1749   void dunderDelItem(const std::string &name) {
1750     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1751                                                      toMlirStringRef(name));
1752     if (!removed)
1753       throw SetPyError(PyExc_KeyError,
1754                        "attempt to delete a non-existent attribute");
1755   }
1756 
1757   intptr_t dunderLen() {
1758     return mlirOperationGetNumAttributes(operation->get());
1759   }
1760 
1761   bool dunderContains(const std::string &name) {
1762     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1763         operation->get(), toMlirStringRef(name)));
1764   }
1765 
1766   static void bind(py::module &m) {
1767     py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
1768         .def("__contains__", &PyOpAttributeMap::dunderContains)
1769         .def("__len__", &PyOpAttributeMap::dunderLen)
1770         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1771         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1772         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1773         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1774   }
1775 
1776 private:
1777   PyOperationRef operation;
1778 };
1779 
1780 } // end namespace
1781 
1782 //------------------------------------------------------------------------------
1783 // Populates the core exports of the 'ir' submodule.
1784 //------------------------------------------------------------------------------
1785 
1786 void mlir::python::populateIRCore(py::module &m) {
1787   //----------------------------------------------------------------------------
1788   // Mapping of MlirContext.
1789   //----------------------------------------------------------------------------
1790   py::class_<PyMlirContext>(m, "Context")
1791       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1792       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1793       .def("_get_context_again",
1794            [](PyMlirContext &self) {
1795              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1796              return ref.releaseObject();
1797            })
1798       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1799       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1800       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1801                              &PyMlirContext::getCapsule)
1802       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1803       .def("__enter__", &PyMlirContext::contextEnter)
1804       .def("__exit__", &PyMlirContext::contextExit)
1805       .def_property_readonly_static(
1806           "current",
1807           [](py::object & /*class*/) {
1808             auto *context = PyThreadContextEntry::getDefaultContext();
1809             if (!context)
1810               throw SetPyError(PyExc_ValueError, "No current Context");
1811             return context;
1812           },
1813           "Gets the Context bound to the current thread or raises ValueError")
1814       .def_property_readonly(
1815           "dialects",
1816           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1817           "Gets a container for accessing dialects by name")
1818       .def_property_readonly(
1819           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1820           "Alias for 'dialect'")
1821       .def(
1822           "get_dialect_descriptor",
1823           [=](PyMlirContext &self, std::string &name) {
1824             MlirDialect dialect = mlirContextGetOrLoadDialect(
1825                 self.get(), {name.data(), name.size()});
1826             if (mlirDialectIsNull(dialect)) {
1827               throw SetPyError(PyExc_ValueError,
1828                                Twine("Dialect '") + name + "' not found");
1829             }
1830             return PyDialectDescriptor(self.getRef(), dialect);
1831           },
1832           "Gets or loads a dialect by name, returning its descriptor object")
1833       .def_property(
1834           "allow_unregistered_dialects",
1835           [](PyMlirContext &self) -> bool {
1836             return mlirContextGetAllowUnregisteredDialects(self.get());
1837           },
1838           [](PyMlirContext &self, bool value) {
1839             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1840           })
1841       .def("enable_multithreading",
1842            [](PyMlirContext &self, bool enable) {
1843              mlirContextEnableMultithreading(self.get(), enable);
1844            })
1845       .def("is_registered_operation",
1846            [](PyMlirContext &self, std::string &name) {
1847              return mlirContextIsRegisteredOperation(
1848                  self.get(), MlirStringRef{name.data(), name.size()});
1849            });
1850 
1851   //----------------------------------------------------------------------------
1852   // Mapping of PyDialectDescriptor
1853   //----------------------------------------------------------------------------
1854   py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
1855       .def_property_readonly("namespace",
1856                              [](PyDialectDescriptor &self) {
1857                                MlirStringRef ns =
1858                                    mlirDialectGetNamespace(self.get());
1859                                return py::str(ns.data, ns.length);
1860                              })
1861       .def("__repr__", [](PyDialectDescriptor &self) {
1862         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1863         std::string repr("<DialectDescriptor ");
1864         repr.append(ns.data, ns.length);
1865         repr.append(">");
1866         return repr;
1867       });
1868 
1869   //----------------------------------------------------------------------------
1870   // Mapping of PyDialects
1871   //----------------------------------------------------------------------------
1872   py::class_<PyDialects>(m, "Dialects")
1873       .def("__getitem__",
1874            [=](PyDialects &self, std::string keyName) {
1875              MlirDialect dialect =
1876                  self.getDialectForKey(keyName, /*attrError=*/false);
1877              py::object descriptor =
1878                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1879              return createCustomDialectWrapper(keyName, std::move(descriptor));
1880            })
1881       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1882         MlirDialect dialect =
1883             self.getDialectForKey(attrName, /*attrError=*/true);
1884         py::object descriptor =
1885             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1886         return createCustomDialectWrapper(attrName, std::move(descriptor));
1887       });
1888 
1889   //----------------------------------------------------------------------------
1890   // Mapping of PyDialect
1891   //----------------------------------------------------------------------------
1892   py::class_<PyDialect>(m, "Dialect")
1893       .def(py::init<py::object>(), "descriptor")
1894       .def_property_readonly(
1895           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1896       .def("__repr__", [](py::object self) {
1897         auto clazz = self.attr("__class__");
1898         return py::str("<Dialect ") +
1899                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1900                clazz.attr("__module__") + py::str(".") +
1901                clazz.attr("__name__") + py::str(")>");
1902       });
1903 
1904   //----------------------------------------------------------------------------
1905   // Mapping of Location
1906   //----------------------------------------------------------------------------
1907   py::class_<PyLocation>(m, "Location")
1908       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
1909       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
1910       .def("__enter__", &PyLocation::contextEnter)
1911       .def("__exit__", &PyLocation::contextExit)
1912       .def("__eq__",
1913            [](PyLocation &self, PyLocation &other) -> bool {
1914              return mlirLocationEqual(self, other);
1915            })
1916       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
1917       .def_property_readonly_static(
1918           "current",
1919           [](py::object & /*class*/) {
1920             auto *loc = PyThreadContextEntry::getDefaultLocation();
1921             if (!loc)
1922               throw SetPyError(PyExc_ValueError, "No current Location");
1923             return loc;
1924           },
1925           "Gets the Location bound to the current thread or raises ValueError")
1926       .def_static(
1927           "unknown",
1928           [](DefaultingPyMlirContext context) {
1929             return PyLocation(context->getRef(),
1930                               mlirLocationUnknownGet(context->get()));
1931           },
1932           py::arg("context") = py::none(),
1933           "Gets a Location representing an unknown location")
1934       .def_static(
1935           "file",
1936           [](std::string filename, int line, int col,
1937              DefaultingPyMlirContext context) {
1938             return PyLocation(
1939                 context->getRef(),
1940                 mlirLocationFileLineColGet(
1941                     context->get(), toMlirStringRef(filename), line, col));
1942           },
1943           py::arg("filename"), py::arg("line"), py::arg("col"),
1944           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
1945       .def_property_readonly(
1946           "context",
1947           [](PyLocation &self) { return self.getContext().getObject(); },
1948           "Context that owns the Location")
1949       .def("__repr__", [](PyLocation &self) {
1950         PyPrintAccumulator printAccum;
1951         mlirLocationPrint(self, printAccum.getCallback(),
1952                           printAccum.getUserData());
1953         return printAccum.join();
1954       });
1955 
1956   //----------------------------------------------------------------------------
1957   // Mapping of Module
1958   //----------------------------------------------------------------------------
1959   py::class_<PyModule>(m, "Module")
1960       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
1961       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
1962       .def_static(
1963           "parse",
1964           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
1965             MlirModule module = mlirModuleCreateParse(
1966                 context->get(), toMlirStringRef(moduleAsm));
1967             // TODO: Rework error reporting once diagnostic engine is exposed
1968             // in C API.
1969             if (mlirModuleIsNull(module)) {
1970               throw SetPyError(
1971                   PyExc_ValueError,
1972                   "Unable to parse module assembly (see diagnostics)");
1973             }
1974             return PyModule::forModule(module).releaseObject();
1975           },
1976           py::arg("asm"), py::arg("context") = py::none(),
1977           kModuleParseDocstring)
1978       .def_static(
1979           "create",
1980           [](DefaultingPyLocation loc) {
1981             MlirModule module = mlirModuleCreateEmpty(loc);
1982             return PyModule::forModule(module).releaseObject();
1983           },
1984           py::arg("loc") = py::none(), "Creates an empty module")
1985       .def_property_readonly(
1986           "context",
1987           [](PyModule &self) { return self.getContext().getObject(); },
1988           "Context that created the Module")
1989       .def_property_readonly(
1990           "operation",
1991           [](PyModule &self) {
1992             return PyOperation::forOperation(self.getContext(),
1993                                              mlirModuleGetOperation(self.get()),
1994                                              self.getRef().releaseObject())
1995                 .releaseObject();
1996           },
1997           "Accesses the module as an operation")
1998       .def_property_readonly(
1999           "body",
2000           [](PyModule &self) {
2001             PyOperationRef module_op = PyOperation::forOperation(
2002                 self.getContext(), mlirModuleGetOperation(self.get()),
2003                 self.getRef().releaseObject());
2004             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2005             return returnBlock;
2006           },
2007           "Return the block for this module")
2008       .def(
2009           "dump",
2010           [](PyModule &self) {
2011             mlirOperationDump(mlirModuleGetOperation(self.get()));
2012           },
2013           kDumpDocstring)
2014       .def(
2015           "__str__",
2016           [](PyModule &self) {
2017             MlirOperation operation = mlirModuleGetOperation(self.get());
2018             PyPrintAccumulator printAccum;
2019             mlirOperationPrint(operation, printAccum.getCallback(),
2020                                printAccum.getUserData());
2021             return printAccum.join();
2022           },
2023           kOperationStrDunderDocstring);
2024 
2025   //----------------------------------------------------------------------------
2026   // Mapping of Operation.
2027   //----------------------------------------------------------------------------
2028   py::class_<PyOperationBase>(m, "_OperationBase")
2029       .def("__eq__",
2030            [](PyOperationBase &self, PyOperationBase &other) {
2031              return &self.getOperation() == &other.getOperation();
2032            })
2033       .def("__eq__",
2034            [](PyOperationBase &self, py::object other) { return false; })
2035       .def_property_readonly("attributes",
2036                              [](PyOperationBase &self) {
2037                                return PyOpAttributeMap(
2038                                    self.getOperation().getRef());
2039                              })
2040       .def_property_readonly("operands",
2041                              [](PyOperationBase &self) {
2042                                return PyOpOperandList(
2043                                    self.getOperation().getRef());
2044                              })
2045       .def_property_readonly("regions",
2046                              [](PyOperationBase &self) {
2047                                return PyRegionList(
2048                                    self.getOperation().getRef());
2049                              })
2050       .def_property_readonly(
2051           "results",
2052           [](PyOperationBase &self) {
2053             return PyOpResultList(self.getOperation().getRef());
2054           },
2055           "Returns the list of Operation results.")
2056       .def_property_readonly(
2057           "result",
2058           [](PyOperationBase &self) {
2059             auto &operation = self.getOperation();
2060             auto numResults = mlirOperationGetNumResults(operation);
2061             if (numResults != 1) {
2062               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2063               throw SetPyError(
2064                   PyExc_ValueError,
2065                   Twine("Cannot call .result on operation ") +
2066                       StringRef(name.data, name.length) + " which has " +
2067                       Twine(numResults) +
2068                       " results (it is only valid for operations with a "
2069                       "single result)");
2070             }
2071             return PyOpResult(operation.getRef(),
2072                               mlirOperationGetResult(operation, 0));
2073           },
2074           "Shortcut to get an op result if it has only one (throws an error "
2075           "otherwise).")
2076       .def("__iter__",
2077            [](PyOperationBase &self) {
2078              return PyRegionIterator(self.getOperation().getRef());
2079            })
2080       .def(
2081           "__str__",
2082           [](PyOperationBase &self) {
2083             return self.getAsm(/*binary=*/false,
2084                                /*largeElementsLimit=*/llvm::None,
2085                                /*enableDebugInfo=*/false,
2086                                /*prettyDebugInfo=*/false,
2087                                /*printGenericOpForm=*/false,
2088                                /*useLocalScope=*/false);
2089           },
2090           "Returns the assembly form of the operation.")
2091       .def("print", &PyOperationBase::print,
2092            // Careful: Lots of arguments must match up with print method.
2093            py::arg("file") = py::none(), py::arg("binary") = false,
2094            py::arg("large_elements_limit") = py::none(),
2095            py::arg("enable_debug_info") = false,
2096            py::arg("pretty_debug_info") = false,
2097            py::arg("print_generic_op_form") = false,
2098            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2099       .def("get_asm", &PyOperationBase::getAsm,
2100            // Careful: Lots of arguments must match up with get_asm method.
2101            py::arg("binary") = false,
2102            py::arg("large_elements_limit") = py::none(),
2103            py::arg("enable_debug_info") = false,
2104            py::arg("pretty_debug_info") = false,
2105            py::arg("print_generic_op_form") = false,
2106            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2107       .def(
2108           "verify",
2109           [](PyOperationBase &self) {
2110             return mlirOperationVerify(self.getOperation());
2111           },
2112           "Verify the operation and return true if it passes, false if it "
2113           "fails.");
2114 
2115   py::class_<PyOperation, PyOperationBase>(m, "Operation")
2116       .def_static("create", &PyOperation::create, py::arg("name"),
2117                   py::arg("results") = py::none(),
2118                   py::arg("operands") = py::none(),
2119                   py::arg("attributes") = py::none(),
2120                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2121                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2122                   kOperationCreateDocstring)
2123       .def_property_readonly("parent",
2124                              [](PyOperation &self) {
2125                                return self.getParentOperation().getObject();
2126                              })
2127       .def("erase", &PyOperation::erase)
2128       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2129                              &PyOperation::getCapsule)
2130       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2131       .def_property_readonly("name",
2132                              [](PyOperation &self) {
2133                                self.checkValid();
2134                                MlirOperation operation = self.get();
2135                                MlirStringRef name = mlirIdentifierStr(
2136                                    mlirOperationGetName(operation));
2137                                return py::str(name.data, name.length);
2138                              })
2139       .def_property_readonly(
2140           "context",
2141           [](PyOperation &self) {
2142             self.checkValid();
2143             return self.getContext().getObject();
2144           },
2145           "Context that owns the Operation")
2146       .def_property_readonly("opview", &PyOperation::createOpView);
2147 
2148   auto opViewClass =
2149       py::class_<PyOpView, PyOperationBase>(m, "OpView")
2150           .def(py::init<py::object>())
2151           .def_property_readonly("operation", &PyOpView::getOperationObject)
2152           .def_property_readonly(
2153               "context",
2154               [](PyOpView &self) {
2155                 return self.getOperation().getContext().getObject();
2156               },
2157               "Context that owns the Operation")
2158           .def("__str__", [](PyOpView &self) {
2159             return py::str(self.getOperationObject());
2160           });
2161   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2162   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2163   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2164   opViewClass.attr("build_generic") = classmethod(
2165       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2166       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2167       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2168       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2169       "Builds a specific, generated OpView based on class level attributes.");
2170 
2171   //----------------------------------------------------------------------------
2172   // Mapping of PyRegion.
2173   //----------------------------------------------------------------------------
2174   py::class_<PyRegion>(m, "Region")
2175       .def_property_readonly(
2176           "blocks",
2177           [](PyRegion &self) {
2178             return PyBlockList(self.getParentOperation(), self.get());
2179           },
2180           "Returns a forward-optimized sequence of blocks.")
2181       .def(
2182           "__iter__",
2183           [](PyRegion &self) {
2184             self.checkValid();
2185             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2186             return PyBlockIterator(self.getParentOperation(), firstBlock);
2187           },
2188           "Iterates over blocks in the region.")
2189       .def("__eq__",
2190            [](PyRegion &self, PyRegion &other) {
2191              return self.get().ptr == other.get().ptr;
2192            })
2193       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2194 
2195   //----------------------------------------------------------------------------
2196   // Mapping of PyBlock.
2197   //----------------------------------------------------------------------------
2198   py::class_<PyBlock>(m, "Block")
2199       .def_property_readonly(
2200           "arguments",
2201           [](PyBlock &self) {
2202             return PyBlockArgumentList(self.getParentOperation(), self.get());
2203           },
2204           "Returns a list of block arguments.")
2205       .def_property_readonly(
2206           "operations",
2207           [](PyBlock &self) {
2208             return PyOperationList(self.getParentOperation(), self.get());
2209           },
2210           "Returns a forward-optimized sequence of operations.")
2211       .def(
2212           "__iter__",
2213           [](PyBlock &self) {
2214             self.checkValid();
2215             MlirOperation firstOperation =
2216                 mlirBlockGetFirstOperation(self.get());
2217             return PyOperationIterator(self.getParentOperation(),
2218                                        firstOperation);
2219           },
2220           "Iterates over operations in the block.")
2221       .def("__eq__",
2222            [](PyBlock &self, PyBlock &other) {
2223              return self.get().ptr == other.get().ptr;
2224            })
2225       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2226       .def(
2227           "__str__",
2228           [](PyBlock &self) {
2229             self.checkValid();
2230             PyPrintAccumulator printAccum;
2231             mlirBlockPrint(self.get(), printAccum.getCallback(),
2232                            printAccum.getUserData());
2233             return printAccum.join();
2234           },
2235           "Returns the assembly form of the block.");
2236 
2237   //----------------------------------------------------------------------------
2238   // Mapping of PyInsertionPoint.
2239   //----------------------------------------------------------------------------
2240 
2241   py::class_<PyInsertionPoint>(m, "InsertionPoint")
2242       .def(py::init<PyBlock &>(), py::arg("block"),
2243            "Inserts after the last operation but still inside the block.")
2244       .def("__enter__", &PyInsertionPoint::contextEnter)
2245       .def("__exit__", &PyInsertionPoint::contextExit)
2246       .def_property_readonly_static(
2247           "current",
2248           [](py::object & /*class*/) {
2249             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2250             if (!ip)
2251               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2252             return ip;
2253           },
2254           "Gets the InsertionPoint bound to the current thread or raises "
2255           "ValueError if none has been set")
2256       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2257            "Inserts before a referenced operation.")
2258       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2259                   py::arg("block"), "Inserts at the beginning of the block.")
2260       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2261                   py::arg("block"), "Inserts before the block terminator.")
2262       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2263            "Inserts an operation.");
2264 
2265   //----------------------------------------------------------------------------
2266   // Mapping of PyAttribute.
2267   //----------------------------------------------------------------------------
2268   py::class_<PyAttribute>(m, "Attribute")
2269       // Delegate to the PyAttribute copy constructor, which will also lifetime
2270       // extend the backing context which owns the MlirAttribute.
2271       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2272            "Casts the passed attribute to the generic Attribute")
2273       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2274                              &PyAttribute::getCapsule)
2275       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2276       .def_static(
2277           "parse",
2278           [](std::string attrSpec, DefaultingPyMlirContext context) {
2279             MlirAttribute type = mlirAttributeParseGet(
2280                 context->get(), toMlirStringRef(attrSpec));
2281             // TODO: Rework error reporting once diagnostic engine is exposed
2282             // in C API.
2283             if (mlirAttributeIsNull(type)) {
2284               throw SetPyError(PyExc_ValueError,
2285                                Twine("Unable to parse attribute: '") +
2286                                    attrSpec + "'");
2287             }
2288             return PyAttribute(context->getRef(), type);
2289           },
2290           py::arg("asm"), py::arg("context") = py::none(),
2291           "Parses an attribute from an assembly form")
2292       .def_property_readonly(
2293           "context",
2294           [](PyAttribute &self) { return self.getContext().getObject(); },
2295           "Context that owns the Attribute")
2296       .def_property_readonly("type",
2297                              [](PyAttribute &self) {
2298                                return PyType(self.getContext()->getRef(),
2299                                              mlirAttributeGetType(self));
2300                              })
2301       .def(
2302           "get_named",
2303           [](PyAttribute &self, std::string name) {
2304             return PyNamedAttribute(self, std::move(name));
2305           },
2306           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2307       .def("__eq__",
2308            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2309       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2310       .def(
2311           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2312           kDumpDocstring)
2313       .def(
2314           "__str__",
2315           [](PyAttribute &self) {
2316             PyPrintAccumulator printAccum;
2317             mlirAttributePrint(self, printAccum.getCallback(),
2318                                printAccum.getUserData());
2319             return printAccum.join();
2320           },
2321           "Returns the assembly form of the Attribute.")
2322       .def("__repr__", [](PyAttribute &self) {
2323         // Generally, assembly formats are not printed for __repr__ because
2324         // this can cause exceptionally long debug output and exceptions.
2325         // However, attribute values are generally considered useful and are
2326         // printed. This may need to be re-evaluated if debug dumps end up
2327         // being excessive.
2328         PyPrintAccumulator printAccum;
2329         printAccum.parts.append("Attribute(");
2330         mlirAttributePrint(self, printAccum.getCallback(),
2331                            printAccum.getUserData());
2332         printAccum.parts.append(")");
2333         return printAccum.join();
2334       });
2335 
2336   //----------------------------------------------------------------------------
2337   // Mapping of PyNamedAttribute
2338   //----------------------------------------------------------------------------
2339   py::class_<PyNamedAttribute>(m, "NamedAttribute")
2340       .def("__repr__",
2341            [](PyNamedAttribute &self) {
2342              PyPrintAccumulator printAccum;
2343              printAccum.parts.append("NamedAttribute(");
2344              printAccum.parts.append(
2345                  mlirIdentifierStr(self.namedAttr.name).data);
2346              printAccum.parts.append("=");
2347              mlirAttributePrint(self.namedAttr.attribute,
2348                                 printAccum.getCallback(),
2349                                 printAccum.getUserData());
2350              printAccum.parts.append(")");
2351              return printAccum.join();
2352            })
2353       .def_property_readonly(
2354           "name",
2355           [](PyNamedAttribute &self) {
2356             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2357                            mlirIdentifierStr(self.namedAttr.name).length);
2358           },
2359           "The name of the NamedAttribute binding")
2360       .def_property_readonly(
2361           "attr",
2362           [](PyNamedAttribute &self) {
2363             // TODO: When named attribute is removed/refactored, also remove
2364             // this constructor (it does an inefficient table lookup).
2365             auto contextRef = PyMlirContext::forContext(
2366                 mlirAttributeGetContext(self.namedAttr.attribute));
2367             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2368           },
2369           py::keep_alive<0, 1>(),
2370           "The underlying generic attribute of the NamedAttribute binding");
2371 
2372   //----------------------------------------------------------------------------
2373   // Mapping of PyType.
2374   //----------------------------------------------------------------------------
2375   py::class_<PyType>(m, "Type")
2376       // Delegate to the PyType copy constructor, which will also lifetime
2377       // extend the backing context which owns the MlirType.
2378       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2379            "Casts the passed type to the generic Type")
2380       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2381       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2382       .def_static(
2383           "parse",
2384           [](std::string typeSpec, DefaultingPyMlirContext context) {
2385             MlirType type =
2386                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2387             // TODO: Rework error reporting once diagnostic engine is exposed
2388             // in C API.
2389             if (mlirTypeIsNull(type)) {
2390               throw SetPyError(PyExc_ValueError,
2391                                Twine("Unable to parse type: '") + typeSpec +
2392                                    "'");
2393             }
2394             return PyType(context->getRef(), type);
2395           },
2396           py::arg("asm"), py::arg("context") = py::none(),
2397           kContextParseTypeDocstring)
2398       .def_property_readonly(
2399           "context", [](PyType &self) { return self.getContext().getObject(); },
2400           "Context that owns the Type")
2401       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2402       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2403       .def(
2404           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2405       .def(
2406           "__str__",
2407           [](PyType &self) {
2408             PyPrintAccumulator printAccum;
2409             mlirTypePrint(self, printAccum.getCallback(),
2410                           printAccum.getUserData());
2411             return printAccum.join();
2412           },
2413           "Returns the assembly form of the type.")
2414       .def("__repr__", [](PyType &self) {
2415         // Generally, assembly formats are not printed for __repr__ because
2416         // this can cause exceptionally long debug output and exceptions.
2417         // However, types are an exception as they typically have compact
2418         // assembly forms and printing them is useful.
2419         PyPrintAccumulator printAccum;
2420         printAccum.parts.append("Type(");
2421         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2422         printAccum.parts.append(")");
2423         return printAccum.join();
2424       });
2425 
2426   //----------------------------------------------------------------------------
2427   // Mapping of Value.
2428   //----------------------------------------------------------------------------
2429   py::class_<PyValue>(m, "Value")
2430       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2431       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2432       .def_property_readonly(
2433           "context",
2434           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2435           "Context in which the value lives.")
2436       .def(
2437           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2438           kDumpDocstring)
2439       .def_property_readonly(
2440           "owner",
2441           [](PyValue &self) {
2442             assert(mlirOperationEqual(self.getParentOperation()->get(),
2443                                       mlirOpResultGetOwner(self.get())) &&
2444                    "expected the owner of the value in Python to match that in "
2445                    "the IR");
2446             return self.getParentOperation().getObject();
2447           })
2448       .def("__eq__",
2449            [](PyValue &self, PyValue &other) {
2450              return self.get().ptr == other.get().ptr;
2451            })
2452       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2453       .def(
2454           "__str__",
2455           [](PyValue &self) {
2456             PyPrintAccumulator printAccum;
2457             printAccum.parts.append("Value(");
2458             mlirValuePrint(self.get(), printAccum.getCallback(),
2459                            printAccum.getUserData());
2460             printAccum.parts.append(")");
2461             return printAccum.join();
2462           },
2463           kValueDunderStrDocstring)
2464       .def_property_readonly("type", [](PyValue &self) {
2465         return PyType(self.getParentOperation()->getContext(),
2466                       mlirValueGetType(self.get()));
2467       });
  // Concrete Value subclasses. NOTE(review): these presumably register pybind
  // subclasses of the "Value" class bound immediately above, which would
  // require them to stay after it — confirm before reordering.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings: list/iterator views used by the Region, Block and
  // operation properties. Registration order is kept as-is; verify any
  // dependencies between these classes before reordering.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  PyGlobalDebugFlag::bind(m);
2485 }
2486