1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include <pybind11/stl.h>
22 
23 namespace py = pybind11;
24 using namespace mlir;
25 using namespace mlir::python;
26 
27 using llvm::SmallVector;
28 using llvm::StringRef;
29 using llvm::Twine;
30 
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
34 
// Python docstring for the Context-level type parsing entry point (the binding
// that uses it is outside this view).
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Python docstring for the file/line/column Location factory (binding outside
// this view).
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Python docstring for parsing a module from its textual assembly (binding
// outside this view).
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Python docstring for generic operation creation. The argument names listed
// here are user-visible and should stay in sync with the binding's py::arg
// names (binding outside this view).
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
71 
// Python docstring for Operation.print(). The keyword names listed here are
// user-visible and must match the py::arg names used where print() is bound
// (the printing options themselves are handled by PyOperationBase::print).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: If not None, elide elements attributes with more
    than this number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";
90 
// Python docstring for Operation.get_asm(); it defers the shared formatting
// keywords to the print() docstring.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Python docstring for the operation's __str__ (default formatting options).
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared Python docstring for the dump() methods of several wrapper types.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Python docstring for BlockList.append (see PyBlockList::bind).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Python docstring for the __str__ of Value wrappers (binding outside this
// view).
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
129 
130 //------------------------------------------------------------------------------
131 // Utilities.
132 //------------------------------------------------------------------------------
133 
134 /// Helper for creating an @classmethod.
135 template <class Func, typename... Args>
136 py::object classmethod(Func f, Args... args) {
137   py::object cf = py::cpp_function(f, args...);
138   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
139 }
140 
141 static py::object
142 createCustomDialectWrapper(const std::string &dialectNamespace,
143                            py::object dialectDescriptor) {
144   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
145   if (!dialectClass) {
146     // Use the base class.
147     return py::cast(PyDialect(std::move(dialectDescriptor)));
148   }
149 
150   // Create the custom implementation.
151   return (*dialectClass)(std::move(dialectDescriptor));
152 }
153 
154 static MlirStringRef toMlirStringRef(const std::string &s) {
155   return mlirStringRefCreate(s.data(), s.size());
156 }
157 
158 /// Wrapper for the global LLVM debugging flag.
159 struct PyGlobalDebugFlag {
160   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
161 
162   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
163 
164   static void bind(py::module &m) {
165     // Debug flags.
166     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug")
167         .def_property_static("flag", &PyGlobalDebugFlag::get,
168                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
169   }
170 };
171 
172 //------------------------------------------------------------------------------
173 // Collections.
174 //------------------------------------------------------------------------------
175 
176 namespace {
177 
178 class PyRegionIterator {
179 public:
180   PyRegionIterator(PyOperationRef operation)
181       : operation(std::move(operation)) {}
182 
183   PyRegionIterator &dunderIter() { return *this; }
184 
185   PyRegion dunderNext() {
186     operation->checkValid();
187     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
188       throw py::stop_iteration();
189     }
190     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
191     return PyRegion(operation, region);
192   }
193 
194   static void bind(py::module &m) {
195     py::class_<PyRegionIterator>(m, "RegionIterator")
196         .def("__iter__", &PyRegionIterator::dunderIter)
197         .def("__next__", &PyRegionIterator::dunderNext);
198   }
199 
200 private:
201   PyOperationRef operation;
202   int nextIndex = 0;
203 };
204 
205 /// Regions of an op are fixed length and indexed numerically so are represented
206 /// with a sequence-like container.
207 class PyRegionList {
208 public:
209   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
210 
211   intptr_t dunderLen() {
212     operation->checkValid();
213     return mlirOperationGetNumRegions(operation->get());
214   }
215 
216   PyRegion dunderGetItem(intptr_t index) {
217     // dunderLen checks validity.
218     if (index < 0 || index >= dunderLen()) {
219       throw SetPyError(PyExc_IndexError,
220                        "attempt to access out of bounds region");
221     }
222     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
223     return PyRegion(operation, region);
224   }
225 
226   static void bind(py::module &m) {
227     py::class_<PyRegionList>(m, "RegionSequence")
228         .def("__len__", &PyRegionList::dunderLen)
229         .def("__getitem__", &PyRegionList::dunderGetItem);
230   }
231 
232 private:
233   PyOperationRef operation;
234 };
235 
236 class PyBlockIterator {
237 public:
238   PyBlockIterator(PyOperationRef operation, MlirBlock next)
239       : operation(std::move(operation)), next(next) {}
240 
241   PyBlockIterator &dunderIter() { return *this; }
242 
243   PyBlock dunderNext() {
244     operation->checkValid();
245     if (mlirBlockIsNull(next)) {
246       throw py::stop_iteration();
247     }
248 
249     PyBlock returnBlock(operation, next);
250     next = mlirBlockGetNextInRegion(next);
251     return returnBlock;
252   }
253 
254   static void bind(py::module &m) {
255     py::class_<PyBlockIterator>(m, "BlockIterator")
256         .def("__iter__", &PyBlockIterator::dunderIter)
257         .def("__next__", &PyBlockIterator::dunderNext);
258   }
259 
260 private:
261   PyOperationRef operation;
262   MlirBlock next;
263 };
264 
265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
266 /// we present them as a more full-featured list-like container but optimize
267 /// it for forward iteration. Blocks are always owned by a region.
268 class PyBlockList {
269 public:
270   PyBlockList(PyOperationRef operation, MlirRegion region)
271       : operation(std::move(operation)), region(region) {}
272 
273   PyBlockIterator dunderIter() {
274     operation->checkValid();
275     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
276   }
277 
278   intptr_t dunderLen() {
279     operation->checkValid();
280     intptr_t count = 0;
281     MlirBlock block = mlirRegionGetFirstBlock(region);
282     while (!mlirBlockIsNull(block)) {
283       count += 1;
284       block = mlirBlockGetNextInRegion(block);
285     }
286     return count;
287   }
288 
289   PyBlock dunderGetItem(intptr_t index) {
290     operation->checkValid();
291     if (index < 0) {
292       throw SetPyError(PyExc_IndexError,
293                        "attempt to access out of bounds block");
294     }
295     MlirBlock block = mlirRegionGetFirstBlock(region);
296     while (!mlirBlockIsNull(block)) {
297       if (index == 0) {
298         return PyBlock(operation, block);
299       }
300       block = mlirBlockGetNextInRegion(block);
301       index -= 1;
302     }
303     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
304   }
305 
306   PyBlock appendBlock(py::args pyArgTypes) {
307     operation->checkValid();
308     llvm::SmallVector<MlirType, 4> argTypes;
309     argTypes.reserve(pyArgTypes.size());
310     for (auto &pyArg : pyArgTypes) {
311       argTypes.push_back(pyArg.cast<PyType &>());
312     }
313 
314     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
315     mlirRegionAppendOwnedBlock(region, block);
316     return PyBlock(operation, block);
317   }
318 
319   static void bind(py::module &m) {
320     py::class_<PyBlockList>(m, "BlockList")
321         .def("__getitem__", &PyBlockList::dunderGetItem)
322         .def("__iter__", &PyBlockList::dunderIter)
323         .def("__len__", &PyBlockList::dunderLen)
324         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
325   }
326 
327 private:
328   PyOperationRef operation;
329   MlirRegion region;
330 };
331 
332 class PyOperationIterator {
333 public:
334   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
335       : parentOperation(std::move(parentOperation)), next(next) {}
336 
337   PyOperationIterator &dunderIter() { return *this; }
338 
339   py::object dunderNext() {
340     parentOperation->checkValid();
341     if (mlirOperationIsNull(next)) {
342       throw py::stop_iteration();
343     }
344 
345     PyOperationRef returnOperation =
346         PyOperation::forOperation(parentOperation->getContext(), next);
347     next = mlirOperationGetNextInBlock(next);
348     return returnOperation->createOpView();
349   }
350 
351   static void bind(py::module &m) {
352     py::class_<PyOperationIterator>(m, "OperationIterator")
353         .def("__iter__", &PyOperationIterator::dunderIter)
354         .def("__next__", &PyOperationIterator::dunderNext);
355   }
356 
357 private:
358   PyOperationRef parentOperation;
359   MlirOperation next;
360 };
361 
362 /// Operations are exposed by the C-API as a forward-only linked list. In
363 /// Python, we present them as a more full-featured list-like container but
364 /// optimize it for forward iteration. Iterable operations are always owned
365 /// by a block.
366 class PyOperationList {
367 public:
368   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
369       : parentOperation(std::move(parentOperation)), block(block) {}
370 
371   PyOperationIterator dunderIter() {
372     parentOperation->checkValid();
373     return PyOperationIterator(parentOperation,
374                                mlirBlockGetFirstOperation(block));
375   }
376 
377   intptr_t dunderLen() {
378     parentOperation->checkValid();
379     intptr_t count = 0;
380     MlirOperation childOp = mlirBlockGetFirstOperation(block);
381     while (!mlirOperationIsNull(childOp)) {
382       count += 1;
383       childOp = mlirOperationGetNextInBlock(childOp);
384     }
385     return count;
386   }
387 
388   py::object dunderGetItem(intptr_t index) {
389     parentOperation->checkValid();
390     if (index < 0) {
391       throw SetPyError(PyExc_IndexError,
392                        "attempt to access out of bounds operation");
393     }
394     MlirOperation childOp = mlirBlockGetFirstOperation(block);
395     while (!mlirOperationIsNull(childOp)) {
396       if (index == 0) {
397         return PyOperation::forOperation(parentOperation->getContext(), childOp)
398             ->createOpView();
399       }
400       childOp = mlirOperationGetNextInBlock(childOp);
401       index -= 1;
402     }
403     throw SetPyError(PyExc_IndexError,
404                      "attempt to access out of bounds operation");
405   }
406 
407   static void bind(py::module &m) {
408     py::class_<PyOperationList>(m, "OperationList")
409         .def("__getitem__", &PyOperationList::dunderGetItem)
410         .def("__iter__", &PyOperationList::dunderIter)
411         .def("__len__", &PyOperationList::dunderLen);
412   }
413 
414 private:
415   PyOperationRef parentOperation;
416   MlirBlock block;
417 };
418 
419 } // namespace
420 
421 //------------------------------------------------------------------------------
422 // PyMlirContext
423 //------------------------------------------------------------------------------
424 
425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
426   py::gil_scoped_acquire acquire;
427   auto &liveContexts = getLiveContexts();
428   liveContexts[context.ptr] = this;
429 }
430 
431 PyMlirContext::~PyMlirContext() {
432   // Note that the only public way to construct an instance is via the
433   // forContext method, which always puts the associated handle into
434   // liveContexts.
435   py::gil_scoped_acquire acquire;
436   getLiveContexts().erase(context.ptr);
437   mlirContextDestroy(context);
438 }
439 
440 py::object PyMlirContext::getCapsule() {
441   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
442 }
443 
444 py::object PyMlirContext::createFromCapsule(py::object capsule) {
445   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
446   if (mlirContextIsNull(rawContext))
447     throw py::error_already_set();
448   return forContext(rawContext).releaseObject();
449 }
450 
451 PyMlirContext *PyMlirContext::createNewContextForInit() {
452   MlirContext context = mlirContextCreate();
453   mlirRegisterAllDialects(context);
454   return new PyMlirContext(context);
455 }
456 
457 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
458   py::gil_scoped_acquire acquire;
459   auto &liveContexts = getLiveContexts();
460   auto it = liveContexts.find(context.ptr);
461   if (it == liveContexts.end()) {
462     // Create.
463     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
464     py::object pyRef = py::cast(unownedContextWrapper);
465     assert(pyRef && "cast to py::object failed");
466     liveContexts[context.ptr] = unownedContextWrapper;
467     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
468   }
469   // Use existing.
470   py::object pyRef = py::cast(it->second);
471   return PyMlirContextRef(it->second, std::move(pyRef));
472 }
473 
474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
475   static LiveContextMap liveContexts;
476   return liveContexts;
477 }
478 
479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
480 
481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
482 
483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
484 
485 pybind11::object PyMlirContext::contextEnter() {
486   return PyThreadContextEntry::pushContext(*this);
487 }
488 
489 void PyMlirContext::contextExit(pybind11::object excType,
490                                 pybind11::object excVal,
491                                 pybind11::object excTb) {
492   PyThreadContextEntry::popContext(*this);
493 }
494 
495 PyMlirContext &DefaultingPyMlirContext::resolve() {
496   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
497   if (!context) {
498     throw SetPyError(
499         PyExc_RuntimeError,
500         "An MLIR function requires a Context but none was provided in the call "
501         "or from the surrounding environment. Either pass to the function with "
502         "a 'context=' argument or establish a default using 'with Context():'");
503   }
504   return *context;
505 }
506 
507 //------------------------------------------------------------------------------
508 // PyThreadContextEntry management
509 //------------------------------------------------------------------------------
510 
511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
512   static thread_local std::vector<PyThreadContextEntry> stack;
513   return stack;
514 }
515 
516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
517   auto &stack = getStack();
518   if (stack.empty())
519     return nullptr;
520   return &stack.back();
521 }
522 
523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
524                                 py::object insertionPoint,
525                                 py::object location) {
526   auto &stack = getStack();
527   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
528                      std::move(location));
529   // If the new stack has more than one entry and the context of the new top
530   // entry matches the previous, copy the insertionPoint and location from the
531   // previous entry if missing from the new top entry.
532   if (stack.size() > 1) {
533     auto &prev = *(stack.rbegin() + 1);
534     auto &current = stack.back();
535     if (current.context.is(prev.context)) {
536       // Default non-context objects from the previous entry.
537       if (!current.insertionPoint)
538         current.insertionPoint = prev.insertionPoint;
539       if (!current.location)
540         current.location = prev.location;
541     }
542   }
543 }
544 
545 PyMlirContext *PyThreadContextEntry::getContext() {
546   if (!context)
547     return nullptr;
548   return py::cast<PyMlirContext *>(context);
549 }
550 
551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
552   if (!insertionPoint)
553     return nullptr;
554   return py::cast<PyInsertionPoint *>(insertionPoint);
555 }
556 
557 PyLocation *PyThreadContextEntry::getLocation() {
558   if (!location)
559     return nullptr;
560   return py::cast<PyLocation *>(location);
561 }
562 
563 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
564   auto *tos = getTopOfStack();
565   return tos ? tos->getContext() : nullptr;
566 }
567 
568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
569   auto *tos = getTopOfStack();
570   return tos ? tos->getInsertionPoint() : nullptr;
571 }
572 
573 PyLocation *PyThreadContextEntry::getDefaultLocation() {
574   auto *tos = getTopOfStack();
575   return tos ? tos->getLocation() : nullptr;
576 }
577 
578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
579   py::object contextObj = py::cast(context);
580   push(FrameKind::Context, /*context=*/contextObj,
581        /*insertionPoint=*/py::object(),
582        /*location=*/py::object());
583   return contextObj;
584 }
585 
586 void PyThreadContextEntry::popContext(PyMlirContext &context) {
587   auto &stack = getStack();
588   if (stack.empty())
589     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
590   auto &tos = stack.back();
591   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
592     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
593   stack.pop_back();
594 }
595 
596 py::object
597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
598   py::object contextObj =
599       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
600   py::object insertionPointObj = py::cast(insertionPoint);
601   push(FrameKind::InsertionPoint,
602        /*context=*/contextObj,
603        /*insertionPoint=*/insertionPointObj,
604        /*location=*/py::object());
605   return insertionPointObj;
606 }
607 
608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
609   auto &stack = getStack();
610   if (stack.empty())
611     throw SetPyError(PyExc_RuntimeError,
612                      "Unbalanced InsertionPoint enter/exit");
613   auto &tos = stack.back();
614   if (tos.frameKind != FrameKind::InsertionPoint &&
615       tos.getInsertionPoint() != &insertionPoint)
616     throw SetPyError(PyExc_RuntimeError,
617                      "Unbalanced InsertionPoint enter/exit");
618   stack.pop_back();
619 }
620 
621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
622   py::object contextObj = location.getContext().getObject();
623   py::object locationObj = py::cast(location);
624   push(FrameKind::Location, /*context=*/contextObj,
625        /*insertionPoint=*/py::object(),
626        /*location=*/locationObj);
627   return locationObj;
628 }
629 
630 void PyThreadContextEntry::popLocation(PyLocation &location) {
631   auto &stack = getStack();
632   if (stack.empty())
633     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
634   auto &tos = stack.back();
635   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
636     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
637   stack.pop_back();
638 }
639 
640 //------------------------------------------------------------------------------
641 // PyDialect, PyDialectDescriptor, PyDialects
642 //------------------------------------------------------------------------------
643 
644 MlirDialect PyDialects::getDialectForKey(const std::string &key,
645                                          bool attrError) {
646   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
647                                                     {key.data(), key.size()});
648   if (mlirDialectIsNull(dialect)) {
649     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
650                      Twine("Dialect '") + key + "' not found");
651   }
652   return dialect;
653 }
654 
655 //------------------------------------------------------------------------------
656 // PyLocation
657 //------------------------------------------------------------------------------
658 
659 py::object PyLocation::getCapsule() {
660   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
661 }
662 
663 PyLocation PyLocation::createFromCapsule(py::object capsule) {
664   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
665   if (mlirLocationIsNull(rawLoc))
666     throw py::error_already_set();
667   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
668                     rawLoc);
669 }
670 
671 py::object PyLocation::contextEnter() {
672   return PyThreadContextEntry::pushLocation(*this);
673 }
674 
675 void PyLocation::contextExit(py::object excType, py::object excVal,
676                              py::object excTb) {
677   PyThreadContextEntry::popLocation(*this);
678 }
679 
680 PyLocation &DefaultingPyLocation::resolve() {
681   auto *location = PyThreadContextEntry::getDefaultLocation();
682   if (!location) {
683     throw SetPyError(
684         PyExc_RuntimeError,
685         "An MLIR function requires a Location but none was provided in the "
686         "call or from the surrounding environment. Either pass to the function "
687         "with a 'loc=' argument or establish a default using 'with loc:'");
688   }
689   return *location;
690 }
691 
692 //------------------------------------------------------------------------------
693 // PyModule
694 //------------------------------------------------------------------------------
695 
696 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
697     : BaseContextObject(std::move(contextRef)), module(module) {}
698 
699 PyModule::~PyModule() {
700   py::gil_scoped_acquire acquire;
701   auto &liveModules = getContext()->liveModules;
702   assert(liveModules.count(module.ptr) == 1 &&
703          "destroying module not in live map");
704   liveModules.erase(module.ptr);
705   mlirModuleDestroy(module);
706 }
707 
708 PyModuleRef PyModule::forModule(MlirModule module) {
709   MlirContext context = mlirModuleGetContext(module);
710   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
711 
712   py::gil_scoped_acquire acquire;
713   auto &liveModules = contextRef->liveModules;
714   auto it = liveModules.find(module.ptr);
715   if (it == liveModules.end()) {
716     // Create.
717     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
718     // Note that the default return value policy on cast is automatic_reference,
719     // which does not take ownership (delete will not be called).
720     // Just be explicit.
721     py::object pyRef =
722         py::cast(unownedModule, py::return_value_policy::take_ownership);
723     unownedModule->handle = pyRef;
724     liveModules[module.ptr] =
725         std::make_pair(unownedModule->handle, unownedModule);
726     return PyModuleRef(unownedModule, std::move(pyRef));
727   }
728   // Use existing.
729   PyModule *existing = it->second.second;
730   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
731   return PyModuleRef(existing, std::move(pyRef));
732 }
733 
734 py::object PyModule::createFromCapsule(py::object capsule) {
735   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
736   if (mlirModuleIsNull(rawModule))
737     throw py::error_already_set();
738   return forModule(rawModule).releaseObject();
739 }
740 
741 py::object PyModule::getCapsule() {
742   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
743 }
744 
745 //------------------------------------------------------------------------------
746 // PyOperation
747 //------------------------------------------------------------------------------
748 
749 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
750     : BaseContextObject(std::move(contextRef)), operation(operation) {}
751 
752 PyOperation::~PyOperation() {
753   // If the operation has already been invalidated there is nothing to do.
754   if (!valid)
755     return;
756   auto &liveOperations = getContext()->liveOperations;
757   assert(liveOperations.count(operation.ptr) == 1 &&
758          "destroying operation not in live map");
759   liveOperations.erase(operation.ptr);
760   if (!isAttached()) {
761     mlirOperationDestroy(operation);
762   }
763 }
764 
765 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
766                                            MlirOperation operation,
767                                            py::object parentKeepAlive) {
768   auto &liveOperations = contextRef->liveOperations;
769   // Create.
770   PyOperation *unownedOperation =
771       new PyOperation(std::move(contextRef), operation);
772   // Note that the default return value policy on cast is automatic_reference,
773   // which does not take ownership (delete will not be called).
774   // Just be explicit.
775   py::object pyRef =
776       py::cast(unownedOperation, py::return_value_policy::take_ownership);
777   unownedOperation->handle = pyRef;
778   if (parentKeepAlive) {
779     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
780   }
781   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
782   return PyOperationRef(unownedOperation, std::move(pyRef));
783 }
784 
785 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
786                                          MlirOperation operation,
787                                          py::object parentKeepAlive) {
788   auto &liveOperations = contextRef->liveOperations;
789   auto it = liveOperations.find(operation.ptr);
790   if (it == liveOperations.end()) {
791     // Create.
792     return createInstance(std::move(contextRef), operation,
793                           std::move(parentKeepAlive));
794   }
795   // Use existing.
796   PyOperation *existing = it->second.second;
797   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
798   return PyOperationRef(existing, std::move(pyRef));
799 }
800 
801 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
802                                            MlirOperation operation,
803                                            py::object parentKeepAlive) {
804   auto &liveOperations = contextRef->liveOperations;
805   assert(liveOperations.count(operation.ptr) == 0 &&
806          "cannot create detached operation that already exists");
807   (void)liveOperations;
808 
809   PyOperationRef created = createInstance(std::move(contextRef), operation,
810                                           std::move(parentKeepAlive));
811   created->attached = false;
812   return created;
813 }
814 
815 void PyOperation::checkValid() const {
816   if (!valid) {
817     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
818   }
819 }
820 
821 void PyOperationBase::print(py::object fileObject, bool binary,
822                             llvm::Optional<int64_t> largeElementsLimit,
823                             bool enableDebugInfo, bool prettyDebugInfo,
824                             bool printGenericOpForm, bool useLocalScope) {
825   PyOperation &operation = getOperation();
826   operation.checkValid();
827   if (fileObject.is_none())
828     fileObject = py::module::import("sys").attr("stdout");
829 
830   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
831     fileObject.attr("write")("// Verification failed, printing generic form\n");
832     printGenericOpForm = true;
833   }
834 
835   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
836   if (largeElementsLimit)
837     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
838   if (enableDebugInfo)
839     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
840   if (printGenericOpForm)
841     mlirOpPrintingFlagsPrintGenericOpForm(flags);
842 
843   PyFileAccumulator accum(fileObject, binary);
844   py::gil_scoped_release();
845   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
846                               accum.getUserData());
847   mlirOpPrintingFlagsDestroy(flags);
848 }
849 
850 py::object PyOperationBase::getAsm(bool binary,
851                                    llvm::Optional<int64_t> largeElementsLimit,
852                                    bool enableDebugInfo, bool prettyDebugInfo,
853                                    bool printGenericOpForm,
854                                    bool useLocalScope) {
855   py::object fileObject;
856   if (binary) {
857     fileObject = py::module::import("io").attr("BytesIO")();
858   } else {
859     fileObject = py::module::import("io").attr("StringIO")();
860   }
861   print(fileObject, /*binary=*/binary,
862         /*largeElementsLimit=*/largeElementsLimit,
863         /*enableDebugInfo=*/enableDebugInfo,
864         /*prettyDebugInfo=*/prettyDebugInfo,
865         /*printGenericOpForm=*/printGenericOpForm,
866         /*useLocalScope=*/useLocalScope);
867 
868   return fileObject.attr("getvalue")();
869 }
870 
871 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
872   checkValid();
873   if (!isAttached())
874     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
875   MlirOperation operation = mlirOperationGetParentOperation(get());
876   if (mlirOperationIsNull(operation))
877     return {};
878   return PyOperation::forOperation(getContext(), operation);
879 }
880 
881 PyBlock PyOperation::getBlock() {
882   checkValid();
883   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
884   MlirBlock block = mlirOperationGetBlock(get());
885   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
886   assert(parentOperation && "Operation has no parent");
887   return PyBlock{std::move(*parentOperation), block};
888 }
889 
890 py::object PyOperation::getCapsule() {
891   checkValid();
892   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
893 }
894 
895 py::object PyOperation::createFromCapsule(py::object capsule) {
896   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
897   if (mlirOperationIsNull(rawOperation))
898     throw py::error_already_set();
899   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
900   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
901       .releaseObject();
902 }
903 
904 py::object PyOperation::create(
905     std::string name, llvm::Optional<std::vector<PyType *>> results,
906     llvm::Optional<std::vector<PyValue *>> operands,
907     llvm::Optional<py::dict> attributes,
908     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
909     DefaultingPyLocation location, py::object maybeIp) {
910   llvm::SmallVector<MlirValue, 4> mlirOperands;
911   llvm::SmallVector<MlirType, 4> mlirResults;
912   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
913   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
914 
915   // General parameter validation.
916   if (regions < 0)
917     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
918 
919   // Unpack/validate operands.
920   if (operands) {
921     mlirOperands.reserve(operands->size());
922     for (PyValue *operand : *operands) {
923       if (!operand)
924         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
925       mlirOperands.push_back(operand->get());
926     }
927   }
928 
929   // Unpack/validate results.
930   if (results) {
931     mlirResults.reserve(results->size());
932     for (PyType *result : *results) {
933       // TODO: Verify result type originate from the same context.
934       if (!result)
935         throw SetPyError(PyExc_ValueError, "result type cannot be None");
936       mlirResults.push_back(*result);
937     }
938   }
939   // Unpack/validate attributes.
940   if (attributes) {
941     mlirAttributes.reserve(attributes->size());
942     for (auto &it : *attributes) {
943       std::string key;
944       try {
945         key = it.first.cast<std::string>();
946       } catch (py::cast_error &err) {
947         std::string msg = "Invalid attribute key (not a string) when "
948                           "attempting to create the operation \"" +
949                           name + "\" (" + err.what() + ")";
950         throw py::cast_error(msg);
951       }
952       try {
953         auto &attribute = it.second.cast<PyAttribute &>();
954         // TODO: Verify attribute originates from the same context.
955         mlirAttributes.emplace_back(std::move(key), attribute);
956       } catch (py::reference_cast_error &) {
957         // This exception seems thrown when the value is "None".
958         std::string msg =
959             "Found an invalid (`None`?) attribute value for the key \"" + key +
960             "\" when attempting to create the operation \"" + name + "\"";
961         throw py::cast_error(msg);
962       } catch (py::cast_error &err) {
963         std::string msg = "Invalid attribute value for the key \"" + key +
964                           "\" when attempting to create the operation \"" +
965                           name + "\" (" + err.what() + ")";
966         throw py::cast_error(msg);
967       }
968     }
969   }
970   // Unpack/validate successors.
971   if (successors) {
972     llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
973     mlirSuccessors.reserve(successors->size());
974     for (auto *successor : *successors) {
975       // TODO: Verify successor originate from the same context.
976       if (!successor)
977         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
978       mlirSuccessors.push_back(successor->get());
979     }
980   }
981 
982   // Apply unpacked/validated to the operation state. Beyond this
983   // point, exceptions cannot be thrown or else the state will leak.
984   MlirOperationState state =
985       mlirOperationStateGet(toMlirStringRef(name), location);
986   if (!mlirOperands.empty())
987     mlirOperationStateAddOperands(&state, mlirOperands.size(),
988                                   mlirOperands.data());
989   if (!mlirResults.empty())
990     mlirOperationStateAddResults(&state, mlirResults.size(),
991                                  mlirResults.data());
992   if (!mlirAttributes.empty()) {
993     // Note that the attribute names directly reference bytes in
994     // mlirAttributes, so that vector must not be changed from here
995     // on.
996     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
997     mlirNamedAttributes.reserve(mlirAttributes.size());
998     for (auto &it : mlirAttributes)
999       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1000           mlirIdentifierGet(mlirAttributeGetContext(it.second),
1001                             toMlirStringRef(it.first)),
1002           it.second));
1003     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1004                                     mlirNamedAttributes.data());
1005   }
1006   if (!mlirSuccessors.empty())
1007     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1008                                     mlirSuccessors.data());
1009   if (regions) {
1010     llvm::SmallVector<MlirRegion, 4> mlirRegions;
1011     mlirRegions.resize(regions);
1012     for (int i = 0; i < regions; ++i)
1013       mlirRegions[i] = mlirRegionCreate();
1014     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1015                                       mlirRegions.data());
1016   }
1017 
1018   // Construct the operation.
1019   MlirOperation operation = mlirOperationCreate(&state);
1020   PyOperationRef created =
1021       PyOperation::createDetached(location->getContext(), operation);
1022 
1023   // InsertPoint active?
1024   if (!maybeIp.is(py::cast(false))) {
1025     PyInsertionPoint *ip;
1026     if (maybeIp.is_none()) {
1027       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1028     } else {
1029       ip = py::cast<PyInsertionPoint *>(maybeIp);
1030     }
1031     if (ip)
1032       ip->insert(*created.get());
1033   }
1034 
1035   return created->createOpView();
1036 }
1037 
1038 py::object PyOperation::createOpView() {
1039   checkValid();
1040   MlirIdentifier ident = mlirOperationGetName(get());
1041   MlirStringRef identStr = mlirIdentifierStr(ident);
1042   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1043       StringRef(identStr.data, identStr.length));
1044   if (opViewClass)
1045     return (*opViewClass)(getRef().getObject());
1046   return py::cast(PyOpView(getRef().getObject()));
1047 }
1048 
1049 void PyOperation::erase() {
1050   checkValid();
1051   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1052   // Python reference to a child operation is live. All children should also
1053   // have their `valid` bit set to false.
1054   auto &liveOperations = getContext()->liveOperations;
1055   if (liveOperations.count(operation.ptr))
1056     liveOperations.erase(operation.ptr);
1057   mlirOperationDestroy(operation);
1058   valid = false;
1059 }
1060 
1061 //------------------------------------------------------------------------------
1062 // PyOpView
1063 //------------------------------------------------------------------------------
1064 
1065 py::object
1066 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1067                        py::list operandList,
1068                        llvm::Optional<py::dict> attributes,
1069                        llvm::Optional<std::vector<PyBlock *>> successors,
1070                        llvm::Optional<int> regions,
1071                        DefaultingPyLocation location, py::object maybeIp) {
1072   PyMlirContextRef context = location->getContext();
1073   // Class level operation construction metadata.
1074   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1075   // Operand and result segment specs are either none, which does no
1076   // variadic unpacking, or a list of ints with segment sizes, where each
1077   // element is either a positive number (typically 1 for a scalar) or -1 to
1078   // indicate that it is derived from the length of the same-indexed operand
1079   // or result (implying that it is a list at that position).
1080   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1081   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1082 
1083   std::vector<uint32_t> operandSegmentLengths;
1084   std::vector<uint32_t> resultSegmentLengths;
1085 
1086   // Validate/determine region count.
1087   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1088   int opMinRegionCount = std::get<0>(opRegionSpec);
1089   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1090   if (!regions) {
1091     regions = opMinRegionCount;
1092   }
1093   if (*regions < opMinRegionCount) {
1094     throw py::value_error(
1095         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1096          llvm::Twine(opMinRegionCount) +
1097          " regions but was built with regions=" + llvm::Twine(*regions))
1098             .str());
1099   }
1100   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1101     throw py::value_error(
1102         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1103          llvm::Twine(opMinRegionCount) +
1104          " regions but was built with regions=" + llvm::Twine(*regions))
1105             .str());
1106   }
1107 
1108   // Unpack results.
1109   std::vector<PyType *> resultTypes;
1110   resultTypes.reserve(resultTypeList.size());
1111   if (resultSegmentSpecObj.is_none()) {
1112     // Non-variadic result unpacking.
1113     for (auto it : llvm::enumerate(resultTypeList)) {
1114       try {
1115         resultTypes.push_back(py::cast<PyType *>(it.value()));
1116         if (!resultTypes.back())
1117           throw py::cast_error();
1118       } catch (py::cast_error &err) {
1119         throw py::value_error((llvm::Twine("Result ") +
1120                                llvm::Twine(it.index()) + " of operation \"" +
1121                                name + "\" must be a Type (" + err.what() + ")")
1122                                   .str());
1123       }
1124     }
1125   } else {
1126     // Sized result unpacking.
1127     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1128     if (resultSegmentSpec.size() != resultTypeList.size()) {
1129       throw py::value_error((llvm::Twine("Operation \"") + name +
1130                              "\" requires " +
1131                              llvm::Twine(resultSegmentSpec.size()) +
1132                              "result segments but was provided " +
1133                              llvm::Twine(resultTypeList.size()))
1134                                 .str());
1135     }
1136     resultSegmentLengths.reserve(resultTypeList.size());
1137     for (auto it :
1138          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1139       int segmentSpec = std::get<1>(it.value());
1140       if (segmentSpec == 1 || segmentSpec == 0) {
1141         // Unpack unary element.
1142         try {
1143           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1144           if (resultType) {
1145             resultTypes.push_back(resultType);
1146             resultSegmentLengths.push_back(1);
1147           } else if (segmentSpec == 0) {
1148             // Allowed to be optional.
1149             resultSegmentLengths.push_back(0);
1150           } else {
1151             throw py::cast_error("was None and result is not optional");
1152           }
1153         } catch (py::cast_error &err) {
1154           throw py::value_error((llvm::Twine("Result ") +
1155                                  llvm::Twine(it.index()) + " of operation \"" +
1156                                  name + "\" must be a Type (" + err.what() +
1157                                  ")")
1158                                     .str());
1159         }
1160       } else if (segmentSpec == -1) {
1161         // Unpack sequence by appending.
1162         try {
1163           if (std::get<0>(it.value()).is_none()) {
1164             // Treat it as an empty list.
1165             resultSegmentLengths.push_back(0);
1166           } else {
1167             // Unpack the list.
1168             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1169             for (py::object segmentItem : segment) {
1170               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1171               if (!resultTypes.back()) {
1172                 throw py::cast_error("contained a None item");
1173               }
1174             }
1175             resultSegmentLengths.push_back(segment.size());
1176           }
1177         } catch (std::exception &err) {
1178           // NOTE: Sloppy to be using a catch-all here, but there are at least
1179           // three different unrelated exceptions that can be thrown in the
1180           // above "casts". Just keep the scope above small and catch them all.
1181           throw py::value_error((llvm::Twine("Result ") +
1182                                  llvm::Twine(it.index()) + " of operation \"" +
1183                                  name + "\" must be a Sequence of Types (" +
1184                                  err.what() + ")")
1185                                     .str());
1186         }
1187       } else {
1188         throw py::value_error("Unexpected segment spec");
1189       }
1190     }
1191   }
1192 
1193   // Unpack operands.
1194   std::vector<PyValue *> operands;
1195   operands.reserve(operands.size());
1196   if (operandSegmentSpecObj.is_none()) {
1197     // Non-sized operand unpacking.
1198     for (auto it : llvm::enumerate(operandList)) {
1199       try {
1200         operands.push_back(py::cast<PyValue *>(it.value()));
1201         if (!operands.back())
1202           throw py::cast_error();
1203       } catch (py::cast_error &err) {
1204         throw py::value_error((llvm::Twine("Operand ") +
1205                                llvm::Twine(it.index()) + " of operation \"" +
1206                                name + "\" must be a Value (" + err.what() + ")")
1207                                   .str());
1208       }
1209     }
1210   } else {
1211     // Sized operand unpacking.
1212     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1213     if (operandSegmentSpec.size() != operandList.size()) {
1214       throw py::value_error((llvm::Twine("Operation \"") + name +
1215                              "\" requires " +
1216                              llvm::Twine(operandSegmentSpec.size()) +
1217                              "operand segments but was provided " +
1218                              llvm::Twine(operandList.size()))
1219                                 .str());
1220     }
1221     operandSegmentLengths.reserve(operandList.size());
1222     for (auto it :
1223          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1224       int segmentSpec = std::get<1>(it.value());
1225       if (segmentSpec == 1 || segmentSpec == 0) {
1226         // Unpack unary element.
1227         try {
1228           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1229           if (operandValue) {
1230             operands.push_back(operandValue);
1231             operandSegmentLengths.push_back(1);
1232           } else if (segmentSpec == 0) {
1233             // Allowed to be optional.
1234             operandSegmentLengths.push_back(0);
1235           } else {
1236             throw py::cast_error("was None and operand is not optional");
1237           }
1238         } catch (py::cast_error &err) {
1239           throw py::value_error((llvm::Twine("Operand ") +
1240                                  llvm::Twine(it.index()) + " of operation \"" +
1241                                  name + "\" must be a Value (" + err.what() +
1242                                  ")")
1243                                     .str());
1244         }
1245       } else if (segmentSpec == -1) {
1246         // Unpack sequence by appending.
1247         try {
1248           if (std::get<0>(it.value()).is_none()) {
1249             // Treat it as an empty list.
1250             operandSegmentLengths.push_back(0);
1251           } else {
1252             // Unpack the list.
1253             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1254             for (py::object segmentItem : segment) {
1255               operands.push_back(py::cast<PyValue *>(segmentItem));
1256               if (!operands.back()) {
1257                 throw py::cast_error("contained a None item");
1258               }
1259             }
1260             operandSegmentLengths.push_back(segment.size());
1261           }
1262         } catch (std::exception &err) {
1263           // NOTE: Sloppy to be using a catch-all here, but there are at least
1264           // three different unrelated exceptions that can be thrown in the
1265           // above "casts". Just keep the scope above small and catch them all.
1266           throw py::value_error((llvm::Twine("Operand ") +
1267                                  llvm::Twine(it.index()) + " of operation \"" +
1268                                  name + "\" must be a Sequence of Values (" +
1269                                  err.what() + ")")
1270                                     .str());
1271         }
1272       } else {
1273         throw py::value_error("Unexpected segment spec");
1274       }
1275     }
1276   }
1277 
1278   // Merge operand/result segment lengths into attributes if needed.
1279   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1280     // Dup.
1281     if (attributes) {
1282       attributes = py::dict(*attributes);
1283     } else {
1284       attributes = py::dict();
1285     }
1286     if (attributes->contains("result_segment_sizes") ||
1287         attributes->contains("operand_segment_sizes")) {
1288       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1289                             "'operand_segment_sizes' attribute is unsupported. "
1290                             "Use Operation.create for such low-level access.");
1291     }
1292 
1293     // Add result_segment_sizes attribute.
1294     if (!resultSegmentLengths.empty()) {
1295       int64_t size = resultSegmentLengths.size();
1296       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1297           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1298           resultSegmentLengths.size(), resultSegmentLengths.data());
1299       (*attributes)["result_segment_sizes"] =
1300           PyAttribute(context, segmentLengthAttr);
1301     }
1302 
1303     // Add operand_segment_sizes attribute.
1304     if (!operandSegmentLengths.empty()) {
1305       int64_t size = operandSegmentLengths.size();
1306       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1307           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1308           operandSegmentLengths.size(), operandSegmentLengths.data());
1309       (*attributes)["operand_segment_sizes"] =
1310           PyAttribute(context, segmentLengthAttr);
1311     }
1312   }
1313 
1314   // Delegate to create.
1315   return PyOperation::create(std::move(name),
1316                              /*results=*/std::move(resultTypes),
1317                              /*operands=*/std::move(operands),
1318                              /*attributes=*/std::move(attributes),
1319                              /*successors=*/std::move(successors),
1320                              /*regions=*/*regions, location, maybeIp);
1321 }
1322 
/// Constructs an OpView over an existing operation.
///
/// `operationObject` may be any Python object castable to PyOperationBase
/// (an Operation or another OpView). The view stores the underlying
/// PyOperation and the canonical Python object of that operation.
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      // NOTE(review): this reads the `operation` member, so it relies on
      // `operation` being declared before `operationObject` in the class
      // definition — confirm in the header.
      operationObject(operation.getRef().getObject()) {}
1328 
1329 py::object PyOpView::createRawSubclass(py::object userClass) {
1330   // This is... a little gross. The typical pattern is to have a pure python
1331   // class that extends OpView like:
1332   //   class AddFOp(_cext.ir.OpView):
1333   //     def __init__(self, loc, lhs, rhs):
1334   //       operation = loc.context.create_operation(
1335   //           "addf", lhs, rhs, results=[lhs.type])
1336   //       super().__init__(operation)
1337   //
1338   // I.e. The goal of the user facing type is to provide a nice constructor
1339   // that has complete freedom for the op under construction. This is at odds
1340   // with our other desire to sometimes create this object by just passing an
1341   // operation (to initialize the base class). We could do *arg and **kwargs
1342   // munging to try to make it work, but instead, we synthesize a new class
1343   // on the fly which extends this user class (AddFOp in this example) and
1344   // *give it* the base class's __init__ method, thus bypassing the
1345   // intermediate subclass's __init__ method entirely. While slightly,
1346   // underhanded, this is safe/legal because the type hierarchy has not changed
1347   // (we just added a new leaf) and we aren't mucking around with __new__.
1348   // Typically, this new class will be stored on the original as "_Raw" and will
1349   // be used for casts and other things that need a variant of the class that
1350   // is initialized purely from an operation.
1351   py::object parentMetaclass =
1352       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1353   py::dict attributes;
1354   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1355   // now.
1356   //   auto opViewType = py::type::of<PyOpView>();
1357   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1358   attributes["__init__"] = opViewType.attr("__init__");
1359   py::str origName = userClass.attr("__name__");
1360   py::str newName = py::str("_") + origName;
1361   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1362 }
1363 
1364 //------------------------------------------------------------------------------
1365 // PyInsertionPoint.
1366 //------------------------------------------------------------------------------
1367 
/// Creates an insertion point with no reference operation; insert() then
/// appends at the end of `block`.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}

/// Creates an insertion point that inserts immediately before the given
/// operation, within that operation's block. getBlock() raises ValueError if
/// the operation is detached.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      // NOTE(review): this dereferences `refOperation`, so it relies on
      // `refOperation` being declared before `block` in the class definition
      // — confirm in the header.
      block((*refOperation)->getBlock()) {}
1373 
1374 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1375   PyOperation &operation = operationBase.getOperation();
1376   if (operation.isAttached())
1377     throw SetPyError(PyExc_ValueError,
1378                      "Attempt to insert operation that is already attached");
1379   block.getParentOperation()->checkValid();
1380   MlirOperation beforeOp = {nullptr};
1381   if (refOperation) {
1382     // Insert before operation.
1383     (*refOperation)->checkValid();
1384     beforeOp = (*refOperation)->get();
1385   } else {
1386     // Insert at end (before null) is only valid if the block does not
1387     // already end in a known terminator (violating this will cause assertion
1388     // failures later).
1389     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1390       throw py::index_error("Cannot insert operation at the end of a block "
1391                             "that already has a terminator. Did you mean to "
1392                             "use 'InsertionPoint.at_block_terminator(block)' "
1393                             "versus 'InsertionPoint(block)'?");
1394     }
1395   }
1396   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1397   operation.setAttached();
1398 }
1399 
1400 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1401   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1402   if (mlirOperationIsNull(firstOp)) {
1403     // Just insert at end.
1404     return PyInsertionPoint(block);
1405   }
1406 
1407   // Insert before first op.
1408   PyOperationRef firstOpRef = PyOperation::forOperation(
1409       block.getParentOperation()->getContext(), firstOp);
1410   return PyInsertionPoint{block, std::move(firstOpRef)};
1411 }
1412 
1413 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1414   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1415   if (mlirOperationIsNull(terminator))
1416     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1417   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1418       block.getParentOperation()->getContext(), terminator);
1419   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1420 }
1421 
1422 py::object PyInsertionPoint::contextEnter() {
1423   return PyThreadContextEntry::pushInsertionPoint(*this);
1424 }
1425 
1426 void PyInsertionPoint::contextExit(pybind11::object excType,
1427                                    pybind11::object excVal,
1428                                    pybind11::object excTb) {
1429   PyThreadContextEntry::popInsertionPoint(*this);
1430 }
1431 
1432 //------------------------------------------------------------------------------
1433 // PyAttribute.
1434 //------------------------------------------------------------------------------
1435 
1436 bool PyAttribute::operator==(const PyAttribute &other) {
1437   return mlirAttributeEqual(attr, other.attr);
1438 }
1439 
1440 py::object PyAttribute::getCapsule() {
1441   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1442 }
1443 
1444 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1445   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1446   if (mlirAttributeIsNull(rawAttr))
1447     throw py::error_already_set();
1448   return PyAttribute(
1449       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1450 }
1451 
1452 //------------------------------------------------------------------------------
1453 // PyNamedAttribute.
1454 //------------------------------------------------------------------------------
1455 
1456 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1457     : ownedName(new std::string(std::move(ownedName))) {
1458   namedAttr = mlirNamedAttributeGet(
1459       mlirIdentifierGet(mlirAttributeGetContext(attr),
1460                         toMlirStringRef(*this->ownedName)),
1461       attr);
1462 }
1463 
1464 //------------------------------------------------------------------------------
1465 // PyType.
1466 //------------------------------------------------------------------------------
1467 
1468 bool PyType::operator==(const PyType &other) {
1469   return mlirTypeEqual(type, other.type);
1470 }
1471 
1472 py::object PyType::getCapsule() {
1473   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1474 }
1475 
1476 PyType PyType::createFromCapsule(py::object capsule) {
1477   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1478   if (mlirTypeIsNull(rawType))
1479     throw py::error_already_set();
1480   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1481                 rawType);
1482 }
1483 
1484 //------------------------------------------------------------------------------
// PyValue and subclasses.
1486 //------------------------------------------------------------------------------
1487 
1488 pybind11::object PyValue::getCapsule() {
1489   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1490 }
1491 
1492 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1493   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1494   if (mlirValueIsNull(value))
1495     throw py::error_already_set();
1496   MlirOperation owner;
1497   if (mlirValueIsAOpResult(value))
1498     owner = mlirOpResultGetOwner(value);
1499   if (mlirValueIsABlockArgument(value))
1500     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1501   if (mlirOperationIsNull(owner))
1502     throw py::error_already_set();
1503   MlirContext ctx = mlirOperationGetContext(owner);
1504   PyOperationRef ownerRef =
1505       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1506   return PyValue(ownerRef, value);
1507 }
1508 
1509 namespace {
1510 /// CRTP base class for Python MLIR values that subclass Value and should be
1511 /// castable from it. The value hierarchy is one level deep and is not supposed
1512 /// to accommodate other levels unless core MLIR changes.
1513 template <typename DerivedTy>
1514 class PyConcreteValue : public PyValue {
1515 public:
1516   // Derived classes must define statics for:
1517   //   IsAFunctionTy isaFunction
1518   //   const char *pyClassName
1519   // and redefine bindDerived.
1520   using ClassTy = py::class_<DerivedTy, PyValue>;
1521   using IsAFunctionTy = bool (*)(MlirValue);
1522 
1523   PyConcreteValue() = default;
1524   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1525       : PyValue(operationRef, value) {}
1526   PyConcreteValue(PyValue &orig)
1527       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1528 
1529   /// Attempts to cast the original value to the derived type and throws on
1530   /// type mismatches.
1531   static MlirValue castFrom(PyValue &orig) {
1532     if (!DerivedTy::isaFunction(orig.get())) {
1533       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1534       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1535                                              DerivedTy::pyClassName +
1536                                              " (from " + origRepr + ")");
1537     }
1538     return orig.get();
1539   }
1540 
1541   /// Binds the Python module objects to functions of this class.
1542   static void bind(py::module &m) {
1543     auto cls = ClassTy(m, DerivedTy::pyClassName);
1544     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1545     DerivedTy::bindDerived(cls);
1546   }
1547 
1548   /// Implemented by derived classes to add methods to the Python subclass.
1549   static void bindDerived(ClassTy &m) {}
1550 };
1551 
1552 /// Python wrapper for MlirBlockArgument.
1553 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1554 public:
1555   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1556   static constexpr const char *pyClassName = "BlockArgument";
1557   using PyConcreteValue::PyConcreteValue;
1558 
1559   static void bindDerived(ClassTy &c) {
1560     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1561       return PyBlock(self.getParentOperation(),
1562                      mlirBlockArgumentGetOwner(self.get()));
1563     });
1564     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1565       return mlirBlockArgumentGetArgNumber(self.get());
1566     });
1567     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1568       return mlirBlockArgumentSetType(self.get(), type);
1569     });
1570   }
1571 };
1572 
1573 /// Python wrapper for MlirOpResult.
1574 class PyOpResult : public PyConcreteValue<PyOpResult> {
1575 public:
1576   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1577   static constexpr const char *pyClassName = "OpResult";
1578   using PyConcreteValue::PyConcreteValue;
1579 
1580   static void bindDerived(ClassTy &c) {
1581     c.def_property_readonly("owner", [](PyOpResult &self) {
1582       assert(
1583           mlirOperationEqual(self.getParentOperation()->get(),
1584                              mlirOpResultGetOwner(self.get())) &&
1585           "expected the owner of the value in Python to match that in the IR");
1586       return self.getParentOperation().getObject();
1587     });
1588     c.def_property_readonly("result_number", [](PyOpResult &self) {
1589       return mlirOpResultGetResultNumber(self.get());
1590     });
1591   }
1592 };
1593 
1594 /// A list of block arguments. Internally, these are stored as consecutive
1595 /// elements, random access is cheap. The argument list is associated with the
1596 /// operation that contains the block (detached blocks are not allowed in
1597 /// Python bindings) and extends its lifetime.
1598 class PyBlockArgumentList {
1599 public:
1600   PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
1601       : operation(std::move(operation)), block(block) {}
1602 
1603   /// Returns the length of the block argument list.
1604   intptr_t dunderLen() {
1605     operation->checkValid();
1606     return mlirBlockGetNumArguments(block);
1607   }
1608 
1609   /// Returns `index`-th element of the block argument list.
1610   PyBlockArgument dunderGetItem(intptr_t index) {
1611     if (index < 0 || index >= dunderLen()) {
1612       throw SetPyError(PyExc_IndexError,
1613                        "attempt to access out of bounds region");
1614     }
1615     PyValue value(operation, mlirBlockGetArgument(block, index));
1616     return PyBlockArgument(value);
1617   }
1618 
1619   /// Defines a Python class in the bindings.
1620   static void bind(py::module &m) {
1621     py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
1622         .def("__len__", &PyBlockArgumentList::dunderLen)
1623         .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
1624   }
1625 
1626 private:
1627   PyOperationRef operation;
1628   MlirBlock block;
1629 };
1630 
1631 /// A list of operation operands. Internally, these are stored as consecutive
1632 /// elements, random access is cheap. The result list is associated with the
1633 /// operation whose results these are, and extends the lifetime of this
1634 /// operation.
1635 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1636 public:
1637   static constexpr const char *pyClassName = "OpOperandList";
1638 
1639   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1640                   intptr_t length = -1, intptr_t step = 1)
1641       : Sliceable(startIndex,
1642                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1643                                : length,
1644                   step),
1645         operation(operation) {}
1646 
1647   intptr_t getNumElements() {
1648     operation->checkValid();
1649     return mlirOperationGetNumOperands(operation->get());
1650   }
1651 
1652   PyValue getElement(intptr_t pos) {
1653     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1654     MlirOperation owner;
1655     if (mlirValueIsAOpResult(operand))
1656       owner = mlirOpResultGetOwner(operand);
1657     else if (mlirValueIsABlockArgument(operand))
1658       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1659     else
1660       assert(false && "Value must be an block arg or op result.");
1661     PyOperationRef pyOwner =
1662         PyOperation::forOperation(operation->getContext(), owner);
1663     return PyValue(pyOwner, operand);
1664   }
1665 
1666   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1667     return PyOpOperandList(operation, startIndex, length, step);
1668   }
1669 
1670   void dunderSetItem(intptr_t index, PyValue value) {
1671     index = wrapIndex(index);
1672     mlirOperationSetOperand(operation->get(), index, value.get());
1673   }
1674 
1675   static void bindDerived(ClassTy &c) {
1676     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1677   }
1678 
1679 private:
1680   PyOperationRef operation;
1681 };
1682 
1683 /// A list of operation results. Internally, these are stored as consecutive
1684 /// elements, random access is cheap. The result list is associated with the
1685 /// operation whose results these are, and extends the lifetime of this
1686 /// operation.
1687 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1688 public:
1689   static constexpr const char *pyClassName = "OpResultList";
1690 
1691   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1692                  intptr_t length = -1, intptr_t step = 1)
1693       : Sliceable(startIndex,
1694                   length == -1 ? mlirOperationGetNumResults(operation->get())
1695                                : length,
1696                   step),
1697         operation(operation) {}
1698 
1699   intptr_t getNumElements() {
1700     operation->checkValid();
1701     return mlirOperationGetNumResults(operation->get());
1702   }
1703 
1704   PyOpResult getElement(intptr_t index) {
1705     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1706     return PyOpResult(value);
1707   }
1708 
1709   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1710     return PyOpResultList(operation, startIndex, length, step);
1711   }
1712 
1713 private:
1714   PyOperationRef operation;
1715 };
1716 
1717 /// A list of operation attributes. Can be indexed by name, producing
1718 /// attributes, or by index, producing named attributes.
1719 class PyOpAttributeMap {
1720 public:
1721   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1722 
1723   PyAttribute dunderGetItemNamed(const std::string &name) {
1724     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1725                                                          toMlirStringRef(name));
1726     if (mlirAttributeIsNull(attr)) {
1727       throw SetPyError(PyExc_KeyError,
1728                        "attempt to access a non-existent attribute");
1729     }
1730     return PyAttribute(operation->getContext(), attr);
1731   }
1732 
1733   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1734     if (index < 0 || index >= dunderLen()) {
1735       throw SetPyError(PyExc_IndexError,
1736                        "attempt to access out of bounds attribute");
1737     }
1738     MlirNamedAttribute namedAttr =
1739         mlirOperationGetAttribute(operation->get(), index);
1740     return PyNamedAttribute(
1741         namedAttr.attribute,
1742         std::string(mlirIdentifierStr(namedAttr.name).data));
1743   }
1744 
1745   void dunderSetItem(const std::string &name, PyAttribute attr) {
1746     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1747                                     attr);
1748   }
1749 
1750   void dunderDelItem(const std::string &name) {
1751     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1752                                                      toMlirStringRef(name));
1753     if (!removed)
1754       throw SetPyError(PyExc_KeyError,
1755                        "attempt to delete a non-existent attribute");
1756   }
1757 
1758   intptr_t dunderLen() {
1759     return mlirOperationGetNumAttributes(operation->get());
1760   }
1761 
1762   bool dunderContains(const std::string &name) {
1763     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1764         operation->get(), toMlirStringRef(name)));
1765   }
1766 
1767   static void bind(py::module &m) {
1768     py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
1769         .def("__contains__", &PyOpAttributeMap::dunderContains)
1770         .def("__len__", &PyOpAttributeMap::dunderLen)
1771         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1772         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1773         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1774         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1775   }
1776 
1777 private:
1778   PyOperationRef operation;
1779 };
1780 
1781 } // end namespace
1782 
1783 //------------------------------------------------------------------------------
1784 // Populates the core exports of the 'ir' submodule.
1785 //------------------------------------------------------------------------------
1786 
1787 void mlir::python::populateIRCore(py::module &m) {
1788   //----------------------------------------------------------------------------
1789   // Mapping of MlirContext.
1790   //----------------------------------------------------------------------------
1791   py::class_<PyMlirContext>(m, "Context")
1792       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1793       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1794       .def("_get_context_again",
1795            [](PyMlirContext &self) {
1796              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1797              return ref.releaseObject();
1798            })
1799       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1800       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1801       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1802                              &PyMlirContext::getCapsule)
1803       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1804       .def("__enter__", &PyMlirContext::contextEnter)
1805       .def("__exit__", &PyMlirContext::contextExit)
1806       .def_property_readonly_static(
1807           "current",
1808           [](py::object & /*class*/) {
1809             auto *context = PyThreadContextEntry::getDefaultContext();
1810             if (!context)
1811               throw SetPyError(PyExc_ValueError, "No current Context");
1812             return context;
1813           },
1814           "Gets the Context bound to the current thread or raises ValueError")
1815       .def_property_readonly(
1816           "dialects",
1817           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1818           "Gets a container for accessing dialects by name")
1819       .def_property_readonly(
1820           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1821           "Alias for 'dialect'")
1822       .def(
1823           "get_dialect_descriptor",
1824           [=](PyMlirContext &self, std::string &name) {
1825             MlirDialect dialect = mlirContextGetOrLoadDialect(
1826                 self.get(), {name.data(), name.size()});
1827             if (mlirDialectIsNull(dialect)) {
1828               throw SetPyError(PyExc_ValueError,
1829                                Twine("Dialect '") + name + "' not found");
1830             }
1831             return PyDialectDescriptor(self.getRef(), dialect);
1832           },
1833           "Gets or loads a dialect by name, returning its descriptor object")
1834       .def_property(
1835           "allow_unregistered_dialects",
1836           [](PyMlirContext &self) -> bool {
1837             return mlirContextGetAllowUnregisteredDialects(self.get());
1838           },
1839           [](PyMlirContext &self, bool value) {
1840             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1841           })
1842       .def("enable_multithreading",
1843            [](PyMlirContext &self, bool enable) {
1844              mlirContextEnableMultithreading(self.get(), enable);
1845            })
1846       .def("is_registered_operation",
1847            [](PyMlirContext &self, std::string &name) {
1848              return mlirContextIsRegisteredOperation(
1849                  self.get(), MlirStringRef{name.data(), name.size()});
1850            });
1851 
1852   //----------------------------------------------------------------------------
1853   // Mapping of PyDialectDescriptor
1854   //----------------------------------------------------------------------------
1855   py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
1856       .def_property_readonly("namespace",
1857                              [](PyDialectDescriptor &self) {
1858                                MlirStringRef ns =
1859                                    mlirDialectGetNamespace(self.get());
1860                                return py::str(ns.data, ns.length);
1861                              })
1862       .def("__repr__", [](PyDialectDescriptor &self) {
1863         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1864         std::string repr("<DialectDescriptor ");
1865         repr.append(ns.data, ns.length);
1866         repr.append(">");
1867         return repr;
1868       });
1869 
1870   //----------------------------------------------------------------------------
1871   // Mapping of PyDialects
1872   //----------------------------------------------------------------------------
1873   py::class_<PyDialects>(m, "Dialects")
1874       .def("__getitem__",
1875            [=](PyDialects &self, std::string keyName) {
1876              MlirDialect dialect =
1877                  self.getDialectForKey(keyName, /*attrError=*/false);
1878              py::object descriptor =
1879                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1880              return createCustomDialectWrapper(keyName, std::move(descriptor));
1881            })
1882       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1883         MlirDialect dialect =
1884             self.getDialectForKey(attrName, /*attrError=*/true);
1885         py::object descriptor =
1886             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1887         return createCustomDialectWrapper(attrName, std::move(descriptor));
1888       });
1889 
1890   //----------------------------------------------------------------------------
1891   // Mapping of PyDialect
1892   //----------------------------------------------------------------------------
1893   py::class_<PyDialect>(m, "Dialect")
1894       .def(py::init<py::object>(), "descriptor")
1895       .def_property_readonly(
1896           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1897       .def("__repr__", [](py::object self) {
1898         auto clazz = self.attr("__class__");
1899         return py::str("<Dialect ") +
1900                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1901                clazz.attr("__module__") + py::str(".") +
1902                clazz.attr("__name__") + py::str(")>");
1903       });
1904 
1905   //----------------------------------------------------------------------------
1906   // Mapping of Location
1907   //----------------------------------------------------------------------------
1908   py::class_<PyLocation>(m, "Location")
1909       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
1910       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
1911       .def("__enter__", &PyLocation::contextEnter)
1912       .def("__exit__", &PyLocation::contextExit)
1913       .def("__eq__",
1914            [](PyLocation &self, PyLocation &other) -> bool {
1915              return mlirLocationEqual(self, other);
1916            })
1917       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
1918       .def_property_readonly_static(
1919           "current",
1920           [](py::object & /*class*/) {
1921             auto *loc = PyThreadContextEntry::getDefaultLocation();
1922             if (!loc)
1923               throw SetPyError(PyExc_ValueError, "No current Location");
1924             return loc;
1925           },
1926           "Gets the Location bound to the current thread or raises ValueError")
1927       .def_static(
1928           "unknown",
1929           [](DefaultingPyMlirContext context) {
1930             return PyLocation(context->getRef(),
1931                               mlirLocationUnknownGet(context->get()));
1932           },
1933           py::arg("context") = py::none(),
1934           "Gets a Location representing an unknown location")
1935       .def_static(
1936           "file",
1937           [](std::string filename, int line, int col,
1938              DefaultingPyMlirContext context) {
1939             return PyLocation(
1940                 context->getRef(),
1941                 mlirLocationFileLineColGet(
1942                     context->get(), toMlirStringRef(filename), line, col));
1943           },
1944           py::arg("filename"), py::arg("line"), py::arg("col"),
1945           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
1946       .def_property_readonly(
1947           "context",
1948           [](PyLocation &self) { return self.getContext().getObject(); },
1949           "Context that owns the Location")
1950       .def("__repr__", [](PyLocation &self) {
1951         PyPrintAccumulator printAccum;
1952         mlirLocationPrint(self, printAccum.getCallback(),
1953                           printAccum.getUserData());
1954         return printAccum.join();
1955       });
1956 
1957   //----------------------------------------------------------------------------
1958   // Mapping of Module
1959   //----------------------------------------------------------------------------
1960   py::class_<PyModule>(m, "Module")
1961       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
1962       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
1963       .def_static(
1964           "parse",
1965           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
1966             MlirModule module = mlirModuleCreateParse(
1967                 context->get(), toMlirStringRef(moduleAsm));
1968             // TODO: Rework error reporting once diagnostic engine is exposed
1969             // in C API.
1970             if (mlirModuleIsNull(module)) {
1971               throw SetPyError(
1972                   PyExc_ValueError,
1973                   "Unable to parse module assembly (see diagnostics)");
1974             }
1975             return PyModule::forModule(module).releaseObject();
1976           },
1977           py::arg("asm"), py::arg("context") = py::none(),
1978           kModuleParseDocstring)
1979       .def_static(
1980           "create",
1981           [](DefaultingPyLocation loc) {
1982             MlirModule module = mlirModuleCreateEmpty(loc);
1983             return PyModule::forModule(module).releaseObject();
1984           },
1985           py::arg("loc") = py::none(), "Creates an empty module")
1986       .def_property_readonly(
1987           "context",
1988           [](PyModule &self) { return self.getContext().getObject(); },
1989           "Context that created the Module")
1990       .def_property_readonly(
1991           "operation",
1992           [](PyModule &self) {
1993             return PyOperation::forOperation(self.getContext(),
1994                                              mlirModuleGetOperation(self.get()),
1995                                              self.getRef().releaseObject())
1996                 .releaseObject();
1997           },
1998           "Accesses the module as an operation")
1999       .def_property_readonly(
2000           "body",
2001           [](PyModule &self) {
2002             PyOperationRef module_op = PyOperation::forOperation(
2003                 self.getContext(), mlirModuleGetOperation(self.get()),
2004                 self.getRef().releaseObject());
2005             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2006             return returnBlock;
2007           },
2008           "Return the block for this module")
2009       .def(
2010           "dump",
2011           [](PyModule &self) {
2012             mlirOperationDump(mlirModuleGetOperation(self.get()));
2013           },
2014           kDumpDocstring)
2015       .def(
2016           "__str__",
2017           [](PyModule &self) {
2018             MlirOperation operation = mlirModuleGetOperation(self.get());
2019             PyPrintAccumulator printAccum;
2020             mlirOperationPrint(operation, printAccum.getCallback(),
2021                                printAccum.getUserData());
2022             return printAccum.join();
2023           },
2024           kOperationStrDunderDocstring);
2025 
2026   //----------------------------------------------------------------------------
2027   // Mapping of Operation.
2028   //----------------------------------------------------------------------------
2029   py::class_<PyOperationBase>(m, "_OperationBase")
2030       .def("__eq__",
2031            [](PyOperationBase &self, PyOperationBase &other) {
2032              return &self.getOperation() == &other.getOperation();
2033            })
2034       .def("__eq__",
2035            [](PyOperationBase &self, py::object other) { return false; })
2036       .def_property_readonly("attributes",
2037                              [](PyOperationBase &self) {
2038                                return PyOpAttributeMap(
2039                                    self.getOperation().getRef());
2040                              })
2041       .def_property_readonly("operands",
2042                              [](PyOperationBase &self) {
2043                                return PyOpOperandList(
2044                                    self.getOperation().getRef());
2045                              })
2046       .def_property_readonly("regions",
2047                              [](PyOperationBase &self) {
2048                                return PyRegionList(
2049                                    self.getOperation().getRef());
2050                              })
2051       .def_property_readonly(
2052           "results",
2053           [](PyOperationBase &self) {
2054             return PyOpResultList(self.getOperation().getRef());
2055           },
2056           "Returns the list of Operation results.")
2057       .def_property_readonly(
2058           "result",
2059           [](PyOperationBase &self) {
2060             auto &operation = self.getOperation();
2061             auto numResults = mlirOperationGetNumResults(operation);
2062             if (numResults != 1) {
2063               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2064               throw SetPyError(
2065                   PyExc_ValueError,
2066                   Twine("Cannot call .result on operation ") +
2067                       StringRef(name.data, name.length) + " which has " +
2068                       Twine(numResults) +
2069                       " results (it is only valid for operations with a "
2070                       "single result)");
2071             }
2072             return PyOpResult(operation.getRef(),
2073                               mlirOperationGetResult(operation, 0));
2074           },
2075           "Shortcut to get an op result if it has only one (throws an error "
2076           "otherwise).")
2077       .def("__iter__",
2078            [](PyOperationBase &self) {
2079              return PyRegionIterator(self.getOperation().getRef());
2080            })
2081       .def(
2082           "__str__",
2083           [](PyOperationBase &self) {
2084             return self.getAsm(/*binary=*/false,
2085                                /*largeElementsLimit=*/llvm::None,
2086                                /*enableDebugInfo=*/false,
2087                                /*prettyDebugInfo=*/false,
2088                                /*printGenericOpForm=*/false,
2089                                /*useLocalScope=*/false);
2090           },
2091           "Returns the assembly form of the operation.")
2092       .def("print", &PyOperationBase::print,
2093            // Careful: Lots of arguments must match up with print method.
2094            py::arg("file") = py::none(), py::arg("binary") = false,
2095            py::arg("large_elements_limit") = py::none(),
2096            py::arg("enable_debug_info") = false,
2097            py::arg("pretty_debug_info") = false,
2098            py::arg("print_generic_op_form") = false,
2099            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2100       .def("get_asm", &PyOperationBase::getAsm,
2101            // Careful: Lots of arguments must match up with get_asm method.
2102            py::arg("binary") = false,
2103            py::arg("large_elements_limit") = py::none(),
2104            py::arg("enable_debug_info") = false,
2105            py::arg("pretty_debug_info") = false,
2106            py::arg("print_generic_op_form") = false,
2107            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2108       .def(
2109           "verify",
2110           [](PyOperationBase &self) {
2111             return mlirOperationVerify(self.getOperation());
2112           },
2113           "Verify the operation and return true if it passes, false if it "
2114           "fails.");
2115 
2116   py::class_<PyOperation, PyOperationBase>(m, "Operation")
2117       .def_static("create", &PyOperation::create, py::arg("name"),
2118                   py::arg("results") = py::none(),
2119                   py::arg("operands") = py::none(),
2120                   py::arg("attributes") = py::none(),
2121                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2122                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2123                   kOperationCreateDocstring)
2124       .def_property_readonly("parent",
2125                              [](PyOperation &self) -> py::object {
2126                                auto parent = self.getParentOperation();
2127                                if (parent)
2128                                  return parent->getObject();
2129                                return py::none();
2130                              })
2131       .def("erase", &PyOperation::erase)
2132       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2133                              &PyOperation::getCapsule)
2134       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2135       .def_property_readonly("name",
2136                              [](PyOperation &self) {
2137                                self.checkValid();
2138                                MlirOperation operation = self.get();
2139                                MlirStringRef name = mlirIdentifierStr(
2140                                    mlirOperationGetName(operation));
2141                                return py::str(name.data, name.length);
2142                              })
2143       .def_property_readonly(
2144           "context",
2145           [](PyOperation &self) {
2146             self.checkValid();
2147             return self.getContext().getObject();
2148           },
2149           "Context that owns the Operation")
2150       .def_property_readonly("opview", &PyOperation::createOpView);
2151 
2152   auto opViewClass =
2153       py::class_<PyOpView, PyOperationBase>(m, "OpView")
2154           .def(py::init<py::object>())
2155           .def_property_readonly("operation", &PyOpView::getOperationObject)
2156           .def_property_readonly(
2157               "context",
2158               [](PyOpView &self) {
2159                 return self.getOperation().getContext().getObject();
2160               },
2161               "Context that owns the Operation")
2162           .def("__str__", [](PyOpView &self) {
2163             return py::str(self.getOperationObject());
2164           });
2165   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2166   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2167   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2168   opViewClass.attr("build_generic") = classmethod(
2169       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2170       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2171       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2172       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2173       "Builds a specific, generated OpView based on class level attributes.");
2174 
2175   //----------------------------------------------------------------------------
2176   // Mapping of PyRegion.
2177   //----------------------------------------------------------------------------
2178   py::class_<PyRegion>(m, "Region")
2179       .def_property_readonly(
2180           "blocks",
2181           [](PyRegion &self) {
2182             return PyBlockList(self.getParentOperation(), self.get());
2183           },
2184           "Returns a forward-optimized sequence of blocks.")
2185       .def(
2186           "__iter__",
2187           [](PyRegion &self) {
2188             self.checkValid();
2189             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2190             return PyBlockIterator(self.getParentOperation(), firstBlock);
2191           },
2192           "Iterates over blocks in the region.")
2193       .def("__eq__",
2194            [](PyRegion &self, PyRegion &other) {
2195              return self.get().ptr == other.get().ptr;
2196            })
2197       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2198 
2199   //----------------------------------------------------------------------------
2200   // Mapping of PyBlock.
2201   //----------------------------------------------------------------------------
2202   py::class_<PyBlock>(m, "Block")
2203       .def_property_readonly(
2204           "owner",
2205           [](PyBlock &self) {
2206             return self.getParentOperation()->createOpView();
2207           },
2208           "Returns the owning operation of this block.")
2209       .def_property_readonly(
2210           "arguments",
2211           [](PyBlock &self) {
2212             return PyBlockArgumentList(self.getParentOperation(), self.get());
2213           },
2214           "Returns a list of block arguments.")
2215       .def_property_readonly(
2216           "operations",
2217           [](PyBlock &self) {
2218             return PyOperationList(self.getParentOperation(), self.get());
2219           },
2220           "Returns a forward-optimized sequence of operations.")
2221       .def(
2222           "__iter__",
2223           [](PyBlock &self) {
2224             self.checkValid();
2225             MlirOperation firstOperation =
2226                 mlirBlockGetFirstOperation(self.get());
2227             return PyOperationIterator(self.getParentOperation(),
2228                                        firstOperation);
2229           },
2230           "Iterates over operations in the block.")
2231       .def("__eq__",
2232            [](PyBlock &self, PyBlock &other) {
2233              return self.get().ptr == other.get().ptr;
2234            })
2235       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2236       .def(
2237           "__str__",
2238           [](PyBlock &self) {
2239             self.checkValid();
2240             PyPrintAccumulator printAccum;
2241             mlirBlockPrint(self.get(), printAccum.getCallback(),
2242                            printAccum.getUserData());
2243             return printAccum.join();
2244           },
2245           "Returns the assembly form of the block.");
2246 
2247   //----------------------------------------------------------------------------
2248   // Mapping of PyInsertionPoint.
2249   //----------------------------------------------------------------------------
2250 
2251   py::class_<PyInsertionPoint>(m, "InsertionPoint")
2252       .def(py::init<PyBlock &>(), py::arg("block"),
2253            "Inserts after the last operation but still inside the block.")
2254       .def("__enter__", &PyInsertionPoint::contextEnter)
2255       .def("__exit__", &PyInsertionPoint::contextExit)
2256       .def_property_readonly_static(
2257           "current",
2258           [](py::object & /*class*/) {
2259             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2260             if (!ip)
2261               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2262             return ip;
2263           },
2264           "Gets the InsertionPoint bound to the current thread or raises "
2265           "ValueError if none has been set")
2266       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2267            "Inserts before a referenced operation.")
2268       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2269                   py::arg("block"), "Inserts at the beginning of the block.")
2270       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2271                   py::arg("block"), "Inserts before the block terminator.")
2272       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2273            "Inserts an operation.");
2274 
2275   //----------------------------------------------------------------------------
2276   // Mapping of PyAttribute.
2277   //----------------------------------------------------------------------------
2278   py::class_<PyAttribute>(m, "Attribute")
2279       // Delegate to the PyAttribute copy constructor, which will also lifetime
2280       // extend the backing context which owns the MlirAttribute.
2281       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2282            "Casts the passed attribute to the generic Attribute")
2283       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2284                              &PyAttribute::getCapsule)
2285       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2286       .def_static(
2287           "parse",
2288           [](std::string attrSpec, DefaultingPyMlirContext context) {
2289             MlirAttribute type = mlirAttributeParseGet(
2290                 context->get(), toMlirStringRef(attrSpec));
2291             // TODO: Rework error reporting once diagnostic engine is exposed
2292             // in C API.
2293             if (mlirAttributeIsNull(type)) {
2294               throw SetPyError(PyExc_ValueError,
2295                                Twine("Unable to parse attribute: '") +
2296                                    attrSpec + "'");
2297             }
2298             return PyAttribute(context->getRef(), type);
2299           },
2300           py::arg("asm"), py::arg("context") = py::none(),
2301           "Parses an attribute from an assembly form")
2302       .def_property_readonly(
2303           "context",
2304           [](PyAttribute &self) { return self.getContext().getObject(); },
2305           "Context that owns the Attribute")
2306       .def_property_readonly("type",
2307                              [](PyAttribute &self) {
2308                                return PyType(self.getContext()->getRef(),
2309                                              mlirAttributeGetType(self));
2310                              })
2311       .def(
2312           "get_named",
2313           [](PyAttribute &self, std::string name) {
2314             return PyNamedAttribute(self, std::move(name));
2315           },
2316           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2317       .def("__eq__",
2318            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2319       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2320       .def(
2321           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2322           kDumpDocstring)
2323       .def(
2324           "__str__",
2325           [](PyAttribute &self) {
2326             PyPrintAccumulator printAccum;
2327             mlirAttributePrint(self, printAccum.getCallback(),
2328                                printAccum.getUserData());
2329             return printAccum.join();
2330           },
2331           "Returns the assembly form of the Attribute.")
2332       .def("__repr__", [](PyAttribute &self) {
2333         // Generally, assembly formats are not printed for __repr__ because
2334         // this can cause exceptionally long debug output and exceptions.
2335         // However, attribute values are generally considered useful and are
2336         // printed. This may need to be re-evaluated if debug dumps end up
2337         // being excessive.
2338         PyPrintAccumulator printAccum;
2339         printAccum.parts.append("Attribute(");
2340         mlirAttributePrint(self, printAccum.getCallback(),
2341                            printAccum.getUserData());
2342         printAccum.parts.append(")");
2343         return printAccum.join();
2344       });
2345 
2346   //----------------------------------------------------------------------------
2347   // Mapping of PyNamedAttribute
2348   //----------------------------------------------------------------------------
2349   py::class_<PyNamedAttribute>(m, "NamedAttribute")
2350       .def("__repr__",
2351            [](PyNamedAttribute &self) {
2352              PyPrintAccumulator printAccum;
2353              printAccum.parts.append("NamedAttribute(");
2354              printAccum.parts.append(
2355                  mlirIdentifierStr(self.namedAttr.name).data);
2356              printAccum.parts.append("=");
2357              mlirAttributePrint(self.namedAttr.attribute,
2358                                 printAccum.getCallback(),
2359                                 printAccum.getUserData());
2360              printAccum.parts.append(")");
2361              return printAccum.join();
2362            })
2363       .def_property_readonly(
2364           "name",
2365           [](PyNamedAttribute &self) {
2366             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2367                            mlirIdentifierStr(self.namedAttr.name).length);
2368           },
2369           "The name of the NamedAttribute binding")
2370       .def_property_readonly(
2371           "attr",
2372           [](PyNamedAttribute &self) {
2373             // TODO: When named attribute is removed/refactored, also remove
2374             // this constructor (it does an inefficient table lookup).
2375             auto contextRef = PyMlirContext::forContext(
2376                 mlirAttributeGetContext(self.namedAttr.attribute));
2377             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2378           },
2379           py::keep_alive<0, 1>(),
2380           "The underlying generic attribute of the NamedAttribute binding");
2381 
2382   //----------------------------------------------------------------------------
2383   // Mapping of PyType.
2384   //----------------------------------------------------------------------------
2385   py::class_<PyType>(m, "Type")
2386       // Delegate to the PyType copy constructor, which will also lifetime
2387       // extend the backing context which owns the MlirType.
2388       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2389            "Casts the passed type to the generic Type")
2390       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2391       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2392       .def_static(
2393           "parse",
2394           [](std::string typeSpec, DefaultingPyMlirContext context) {
2395             MlirType type =
2396                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2397             // TODO: Rework error reporting once diagnostic engine is exposed
2398             // in C API.
2399             if (mlirTypeIsNull(type)) {
2400               throw SetPyError(PyExc_ValueError,
2401                                Twine("Unable to parse type: '") + typeSpec +
2402                                    "'");
2403             }
2404             return PyType(context->getRef(), type);
2405           },
2406           py::arg("asm"), py::arg("context") = py::none(),
2407           kContextParseTypeDocstring)
2408       .def_property_readonly(
2409           "context", [](PyType &self) { return self.getContext().getObject(); },
2410           "Context that owns the Type")
2411       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2412       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2413       .def(
2414           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2415       .def(
2416           "__str__",
2417           [](PyType &self) {
2418             PyPrintAccumulator printAccum;
2419             mlirTypePrint(self, printAccum.getCallback(),
2420                           printAccum.getUserData());
2421             return printAccum.join();
2422           },
2423           "Returns the assembly form of the type.")
2424       .def("__repr__", [](PyType &self) {
2425         // Generally, assembly formats are not printed for __repr__ because
2426         // this can cause exceptionally long debug output and exceptions.
2427         // However, types are an exception as they typically have compact
2428         // assembly forms and printing them is useful.
2429         PyPrintAccumulator printAccum;
2430         printAccum.parts.append("Type(");
2431         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2432         printAccum.parts.append(")");
2433         return printAccum.join();
2434       });
2435 
2436   //----------------------------------------------------------------------------
2437   // Mapping of Value.
2438   //----------------------------------------------------------------------------
2439   py::class_<PyValue>(m, "Value")
2440       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2441       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2442       .def_property_readonly(
2443           "context",
2444           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2445           "Context in which the value lives.")
2446       .def(
2447           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2448           kDumpDocstring)
2449       .def_property_readonly(
2450           "owner",
2451           [](PyValue &self) {
2452             assert(mlirOperationEqual(self.getParentOperation()->get(),
2453                                       mlirOpResultGetOwner(self.get())) &&
2454                    "expected the owner of the value in Python to match that in "
2455                    "the IR");
2456             return self.getParentOperation().getObject();
2457           })
2458       .def("__eq__",
2459            [](PyValue &self, PyValue &other) {
2460              return self.get().ptr == other.get().ptr;
2461            })
2462       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2463       .def(
2464           "__str__",
2465           [](PyValue &self) {
2466             PyPrintAccumulator printAccum;
2467             printAccum.parts.append("Value(");
2468             mlirValuePrint(self.get(), printAccum.getCallback(),
2469                            printAccum.getUserData());
2470             printAccum.parts.append(")");
2471             return printAccum.join();
2472           },
2473           kValueDunderStrDocstring)
2474       .def_property_readonly("type", [](PyValue &self) {
2475         return PyType(self.getParentOperation()->getContext(),
2476                       mlirValueGetType(self.get()));
2477       });
  // Concrete Value kinds. NOTE(review): these are registered after the `Value`
  // class above, presumably because their bindings name PyValue as the Python
  // base class — confirm before reordering.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings: the list/iterator wrapper classes returned by the
  // properties and __iter__ methods bound above (e.g. Region.blocks,
  // Region.__iter__, Block.arguments, Block.operations, Block.__iter__).
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings. NOTE(review): presumably exposes the global MLIR debug
  // flag from mlir-c/Debug.h — confirm in PyGlobalDebugFlag::bind.
  PyGlobalDebugFlag::bind(m);
2495 }
2496