1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include <pybind11/stl.h>
22 
23 namespace py = pybind11;
24 using namespace mlir;
25 using namespace mlir::python;
26 
27 using llvm::SmallVector;
28 using llvm::StringRef;
29 using llvm::Twine;
30 
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
34 
/// Python-visible docstring text for parsing a Type from its assembly form.
/// The binding that attaches it lives outside this section.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

/// Python-visible docstring text for building a file/line/column Location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

/// Python-visible docstring text for parsing a module from its assembly form.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

/// Python-visible docstring text for Operation creation. The "Args" section
/// is intended to describe the Python keyword arguments of the binding
/// (NOTE(review): binding not shown here; keep names in sync with it).
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
71 
/// Python-visible docstring for Operation.print. The names in the "Args"
/// section are what users read via help(), so they are written in the
/// snake_case form of the C++ parameters of PyOperationBase::print
/// (e.g. `useLocalScope` -> `use_local_scope`).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";
90 
/// Python-visible docstring text for Operation.get_asm; it defers to the
/// print() docstring for the shared keyword arguments.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

/// Python-visible docstring text for Operation.__str__.
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

/// Generic docstring text for dump() methods.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

/// Docstring text for BlockList.append (bound in PyBlockList::bind).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

/// Python-visible docstring text for Value.__str__.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
129 
130 //------------------------------------------------------------------------------
131 // Utilities.
132 //------------------------------------------------------------------------------
133 
134 /// Helper for creating an @classmethod.
135 template <class Func, typename... Args>
136 py::object classmethod(Func f, Args... args) {
137   py::object cf = py::cpp_function(f, args...);
138   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
139 }
140 
141 static py::object
142 createCustomDialectWrapper(const std::string &dialectNamespace,
143                            py::object dialectDescriptor) {
144   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
145   if (!dialectClass) {
146     // Use the base class.
147     return py::cast(PyDialect(std::move(dialectDescriptor)));
148   }
149 
150   // Create the custom implementation.
151   return (*dialectClass)(std::move(dialectDescriptor));
152 }
153 
154 static MlirStringRef toMlirStringRef(const std::string &s) {
155   return mlirStringRefCreate(s.data(), s.size());
156 }
157 
158 /// Wrapper for the global LLVM debugging flag.
159 struct PyGlobalDebugFlag {
160   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
161 
162   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
163 
164   static void bind(py::module &m) {
165     // Debug flags.
166     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
167         .def_property_static("flag", &PyGlobalDebugFlag::get,
168                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
169   }
170 };
171 
172 //------------------------------------------------------------------------------
173 // Collections.
174 //------------------------------------------------------------------------------
175 
176 namespace {
177 
178 class PyRegionIterator {
179 public:
180   PyRegionIterator(PyOperationRef operation)
181       : operation(std::move(operation)) {}
182 
183   PyRegionIterator &dunderIter() { return *this; }
184 
185   PyRegion dunderNext() {
186     operation->checkValid();
187     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
188       throw py::stop_iteration();
189     }
190     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
191     return PyRegion(operation, region);
192   }
193 
194   static void bind(py::module &m) {
195     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
196         .def("__iter__", &PyRegionIterator::dunderIter)
197         .def("__next__", &PyRegionIterator::dunderNext);
198   }
199 
200 private:
201   PyOperationRef operation;
202   int nextIndex = 0;
203 };
204 
205 /// Regions of an op are fixed length and indexed numerically so are represented
206 /// with a sequence-like container.
207 class PyRegionList {
208 public:
209   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
210 
211   intptr_t dunderLen() {
212     operation->checkValid();
213     return mlirOperationGetNumRegions(operation->get());
214   }
215 
216   PyRegion dunderGetItem(intptr_t index) {
217     // dunderLen checks validity.
218     if (index < 0 || index >= dunderLen()) {
219       throw SetPyError(PyExc_IndexError,
220                        "attempt to access out of bounds region");
221     }
222     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
223     return PyRegion(operation, region);
224   }
225 
226   static void bind(py::module &m) {
227     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
228         .def("__len__", &PyRegionList::dunderLen)
229         .def("__getitem__", &PyRegionList::dunderGetItem);
230   }
231 
232 private:
233   PyOperationRef operation;
234 };
235 
236 class PyBlockIterator {
237 public:
238   PyBlockIterator(PyOperationRef operation, MlirBlock next)
239       : operation(std::move(operation)), next(next) {}
240 
241   PyBlockIterator &dunderIter() { return *this; }
242 
243   PyBlock dunderNext() {
244     operation->checkValid();
245     if (mlirBlockIsNull(next)) {
246       throw py::stop_iteration();
247     }
248 
249     PyBlock returnBlock(operation, next);
250     next = mlirBlockGetNextInRegion(next);
251     return returnBlock;
252   }
253 
254   static void bind(py::module &m) {
255     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
256         .def("__iter__", &PyBlockIterator::dunderIter)
257         .def("__next__", &PyBlockIterator::dunderNext);
258   }
259 
260 private:
261   PyOperationRef operation;
262   MlirBlock next;
263 };
264 
265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
266 /// we present them as a more full-featured list-like container but optimize
267 /// it for forward iteration. Blocks are always owned by a region.
268 class PyBlockList {
269 public:
270   PyBlockList(PyOperationRef operation, MlirRegion region)
271       : operation(std::move(operation)), region(region) {}
272 
273   PyBlockIterator dunderIter() {
274     operation->checkValid();
275     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
276   }
277 
278   intptr_t dunderLen() {
279     operation->checkValid();
280     intptr_t count = 0;
281     MlirBlock block = mlirRegionGetFirstBlock(region);
282     while (!mlirBlockIsNull(block)) {
283       count += 1;
284       block = mlirBlockGetNextInRegion(block);
285     }
286     return count;
287   }
288 
289   PyBlock dunderGetItem(intptr_t index) {
290     operation->checkValid();
291     if (index < 0) {
292       throw SetPyError(PyExc_IndexError,
293                        "attempt to access out of bounds block");
294     }
295     MlirBlock block = mlirRegionGetFirstBlock(region);
296     while (!mlirBlockIsNull(block)) {
297       if (index == 0) {
298         return PyBlock(operation, block);
299       }
300       block = mlirBlockGetNextInRegion(block);
301       index -= 1;
302     }
303     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
304   }
305 
306   PyBlock appendBlock(py::args pyArgTypes) {
307     operation->checkValid();
308     llvm::SmallVector<MlirType, 4> argTypes;
309     argTypes.reserve(pyArgTypes.size());
310     for (auto &pyArg : pyArgTypes) {
311       argTypes.push_back(pyArg.cast<PyType &>());
312     }
313 
314     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
315     mlirRegionAppendOwnedBlock(region, block);
316     return PyBlock(operation, block);
317   }
318 
319   static void bind(py::module &m) {
320     py::class_<PyBlockList>(m, "BlockList", py::module_local())
321         .def("__getitem__", &PyBlockList::dunderGetItem)
322         .def("__iter__", &PyBlockList::dunderIter)
323         .def("__len__", &PyBlockList::dunderLen)
324         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
325   }
326 
327 private:
328   PyOperationRef operation;
329   MlirRegion region;
330 };
331 
332 class PyOperationIterator {
333 public:
334   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
335       : parentOperation(std::move(parentOperation)), next(next) {}
336 
337   PyOperationIterator &dunderIter() { return *this; }
338 
339   py::object dunderNext() {
340     parentOperation->checkValid();
341     if (mlirOperationIsNull(next)) {
342       throw py::stop_iteration();
343     }
344 
345     PyOperationRef returnOperation =
346         PyOperation::forOperation(parentOperation->getContext(), next);
347     next = mlirOperationGetNextInBlock(next);
348     return returnOperation->createOpView();
349   }
350 
351   static void bind(py::module &m) {
352     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
353         .def("__iter__", &PyOperationIterator::dunderIter)
354         .def("__next__", &PyOperationIterator::dunderNext);
355   }
356 
357 private:
358   PyOperationRef parentOperation;
359   MlirOperation next;
360 };
361 
362 /// Operations are exposed by the C-API as a forward-only linked list. In
363 /// Python, we present them as a more full-featured list-like container but
364 /// optimize it for forward iteration. Iterable operations are always owned
365 /// by a block.
366 class PyOperationList {
367 public:
368   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
369       : parentOperation(std::move(parentOperation)), block(block) {}
370 
371   PyOperationIterator dunderIter() {
372     parentOperation->checkValid();
373     return PyOperationIterator(parentOperation,
374                                mlirBlockGetFirstOperation(block));
375   }
376 
377   intptr_t dunderLen() {
378     parentOperation->checkValid();
379     intptr_t count = 0;
380     MlirOperation childOp = mlirBlockGetFirstOperation(block);
381     while (!mlirOperationIsNull(childOp)) {
382       count += 1;
383       childOp = mlirOperationGetNextInBlock(childOp);
384     }
385     return count;
386   }
387 
388   py::object dunderGetItem(intptr_t index) {
389     parentOperation->checkValid();
390     if (index < 0) {
391       throw SetPyError(PyExc_IndexError,
392                        "attempt to access out of bounds operation");
393     }
394     MlirOperation childOp = mlirBlockGetFirstOperation(block);
395     while (!mlirOperationIsNull(childOp)) {
396       if (index == 0) {
397         return PyOperation::forOperation(parentOperation->getContext(), childOp)
398             ->createOpView();
399       }
400       childOp = mlirOperationGetNextInBlock(childOp);
401       index -= 1;
402     }
403     throw SetPyError(PyExc_IndexError,
404                      "attempt to access out of bounds operation");
405   }
406 
407   static void bind(py::module &m) {
408     py::class_<PyOperationList>(m, "OperationList", py::module_local())
409         .def("__getitem__", &PyOperationList::dunderGetItem)
410         .def("__iter__", &PyOperationList::dunderIter)
411         .def("__len__", &PyOperationList::dunderLen);
412   }
413 
414 private:
415   PyOperationRef parentOperation;
416   MlirBlock block;
417 };
418 
419 } // namespace
420 
421 //------------------------------------------------------------------------------
422 // PyMlirContext
423 //------------------------------------------------------------------------------
424 
425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
426   py::gil_scoped_acquire acquire;
427   auto &liveContexts = getLiveContexts();
428   liveContexts[context.ptr] = this;
429 }
430 
431 PyMlirContext::~PyMlirContext() {
432   // Note that the only public way to construct an instance is via the
433   // forContext method, which always puts the associated handle into
434   // liveContexts.
435   py::gil_scoped_acquire acquire;
436   getLiveContexts().erase(context.ptr);
437   mlirContextDestroy(context);
438 }
439 
440 py::object PyMlirContext::getCapsule() {
441   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
442 }
443 
444 py::object PyMlirContext::createFromCapsule(py::object capsule) {
445   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
446   if (mlirContextIsNull(rawContext))
447     throw py::error_already_set();
448   return forContext(rawContext).releaseObject();
449 }
450 
451 PyMlirContext *PyMlirContext::createNewContextForInit() {
452   MlirContext context = mlirContextCreate();
453   mlirRegisterAllDialects(context);
454   return new PyMlirContext(context);
455 }
456 
457 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
458   py::gil_scoped_acquire acquire;
459   auto &liveContexts = getLiveContexts();
460   auto it = liveContexts.find(context.ptr);
461   if (it == liveContexts.end()) {
462     // Create.
463     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
464     py::object pyRef = py::cast(unownedContextWrapper);
465     assert(pyRef && "cast to py::object failed");
466     liveContexts[context.ptr] = unownedContextWrapper;
467     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
468   }
469   // Use existing.
470   py::object pyRef = py::cast(it->second);
471   return PyMlirContextRef(it->second, std::move(pyRef));
472 }
473 
474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
475   static LiveContextMap liveContexts;
476   return liveContexts;
477 }
478 
479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
480 
481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
482 
483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
484 
485 pybind11::object PyMlirContext::contextEnter() {
486   return PyThreadContextEntry::pushContext(*this);
487 }
488 
489 void PyMlirContext::contextExit(pybind11::object excType,
490                                 pybind11::object excVal,
491                                 pybind11::object excTb) {
492   PyThreadContextEntry::popContext(*this);
493 }
494 
495 PyMlirContext &DefaultingPyMlirContext::resolve() {
496   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
497   if (!context) {
498     throw SetPyError(
499         PyExc_RuntimeError,
500         "An MLIR function requires a Context but none was provided in the call "
501         "or from the surrounding environment. Either pass to the function with "
502         "a 'context=' argument or establish a default using 'with Context():'");
503   }
504   return *context;
505 }
506 
507 //------------------------------------------------------------------------------
508 // PyThreadContextEntry management
509 //------------------------------------------------------------------------------
510 
511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
512   static thread_local std::vector<PyThreadContextEntry> stack;
513   return stack;
514 }
515 
516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
517   auto &stack = getStack();
518   if (stack.empty())
519     return nullptr;
520   return &stack.back();
521 }
522 
523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
524                                 py::object insertionPoint,
525                                 py::object location) {
526   auto &stack = getStack();
527   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
528                      std::move(location));
529   // If the new stack has more than one entry and the context of the new top
530   // entry matches the previous, copy the insertionPoint and location from the
531   // previous entry if missing from the new top entry.
532   if (stack.size() > 1) {
533     auto &prev = *(stack.rbegin() + 1);
534     auto &current = stack.back();
535     if (current.context.is(prev.context)) {
536       // Default non-context objects from the previous entry.
537       if (!current.insertionPoint)
538         current.insertionPoint = prev.insertionPoint;
539       if (!current.location)
540         current.location = prev.location;
541     }
542   }
543 }
544 
545 PyMlirContext *PyThreadContextEntry::getContext() {
546   if (!context)
547     return nullptr;
548   return py::cast<PyMlirContext *>(context);
549 }
550 
551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
552   if (!insertionPoint)
553     return nullptr;
554   return py::cast<PyInsertionPoint *>(insertionPoint);
555 }
556 
557 PyLocation *PyThreadContextEntry::getLocation() {
558   if (!location)
559     return nullptr;
560   return py::cast<PyLocation *>(location);
561 }
562 
563 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
564   auto *tos = getTopOfStack();
565   return tos ? tos->getContext() : nullptr;
566 }
567 
568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
569   auto *tos = getTopOfStack();
570   return tos ? tos->getInsertionPoint() : nullptr;
571 }
572 
573 PyLocation *PyThreadContextEntry::getDefaultLocation() {
574   auto *tos = getTopOfStack();
575   return tos ? tos->getLocation() : nullptr;
576 }
577 
578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
579   py::object contextObj = py::cast(context);
580   push(FrameKind::Context, /*context=*/contextObj,
581        /*insertionPoint=*/py::object(),
582        /*location=*/py::object());
583   return contextObj;
584 }
585 
586 void PyThreadContextEntry::popContext(PyMlirContext &context) {
587   auto &stack = getStack();
588   if (stack.empty())
589     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
590   auto &tos = stack.back();
591   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
592     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
593   stack.pop_back();
594 }
595 
596 py::object
597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
598   py::object contextObj =
599       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
600   py::object insertionPointObj = py::cast(insertionPoint);
601   push(FrameKind::InsertionPoint,
602        /*context=*/contextObj,
603        /*insertionPoint=*/insertionPointObj,
604        /*location=*/py::object());
605   return insertionPointObj;
606 }
607 
608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
609   auto &stack = getStack();
610   if (stack.empty())
611     throw SetPyError(PyExc_RuntimeError,
612                      "Unbalanced InsertionPoint enter/exit");
613   auto &tos = stack.back();
614   if (tos.frameKind != FrameKind::InsertionPoint &&
615       tos.getInsertionPoint() != &insertionPoint)
616     throw SetPyError(PyExc_RuntimeError,
617                      "Unbalanced InsertionPoint enter/exit");
618   stack.pop_back();
619 }
620 
621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
622   py::object contextObj = location.getContext().getObject();
623   py::object locationObj = py::cast(location);
624   push(FrameKind::Location, /*context=*/contextObj,
625        /*insertionPoint=*/py::object(),
626        /*location=*/locationObj);
627   return locationObj;
628 }
629 
630 void PyThreadContextEntry::popLocation(PyLocation &location) {
631   auto &stack = getStack();
632   if (stack.empty())
633     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
634   auto &tos = stack.back();
635   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
636     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
637   stack.pop_back();
638 }
639 
640 //------------------------------------------------------------------------------
641 // PyDialect, PyDialectDescriptor, PyDialects
642 //------------------------------------------------------------------------------
643 
644 MlirDialect PyDialects::getDialectForKey(const std::string &key,
645                                          bool attrError) {
646   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
647                                                     {key.data(), key.size()});
648   if (mlirDialectIsNull(dialect)) {
649     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
650                      Twine("Dialect '") + key + "' not found");
651   }
652   return dialect;
653 }
654 
655 //------------------------------------------------------------------------------
656 // PyLocation
657 //------------------------------------------------------------------------------
658 
659 py::object PyLocation::getCapsule() {
660   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
661 }
662 
663 PyLocation PyLocation::createFromCapsule(py::object capsule) {
664   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
665   if (mlirLocationIsNull(rawLoc))
666     throw py::error_already_set();
667   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
668                     rawLoc);
669 }
670 
671 py::object PyLocation::contextEnter() {
672   return PyThreadContextEntry::pushLocation(*this);
673 }
674 
675 void PyLocation::contextExit(py::object excType, py::object excVal,
676                              py::object excTb) {
677   PyThreadContextEntry::popLocation(*this);
678 }
679 
680 PyLocation &DefaultingPyLocation::resolve() {
681   auto *location = PyThreadContextEntry::getDefaultLocation();
682   if (!location) {
683     throw SetPyError(
684         PyExc_RuntimeError,
685         "An MLIR function requires a Location but none was provided in the "
686         "call or from the surrounding environment. Either pass to the function "
687         "with a 'loc=' argument or establish a default using 'with loc:'");
688   }
689   return *location;
690 }
691 
692 //------------------------------------------------------------------------------
693 // PyModule
694 //------------------------------------------------------------------------------
695 
696 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
697     : BaseContextObject(std::move(contextRef)), module(module) {}
698 
699 PyModule::~PyModule() {
700   py::gil_scoped_acquire acquire;
701   auto &liveModules = getContext()->liveModules;
702   assert(liveModules.count(module.ptr) == 1 &&
703          "destroying module not in live map");
704   liveModules.erase(module.ptr);
705   mlirModuleDestroy(module);
706 }
707 
708 PyModuleRef PyModule::forModule(MlirModule module) {
709   MlirContext context = mlirModuleGetContext(module);
710   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
711 
712   py::gil_scoped_acquire acquire;
713   auto &liveModules = contextRef->liveModules;
714   auto it = liveModules.find(module.ptr);
715   if (it == liveModules.end()) {
716     // Create.
717     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
718     // Note that the default return value policy on cast is automatic_reference,
719     // which does not take ownership (delete will not be called).
720     // Just be explicit.
721     py::object pyRef =
722         py::cast(unownedModule, py::return_value_policy::take_ownership);
723     unownedModule->handle = pyRef;
724     liveModules[module.ptr] =
725         std::make_pair(unownedModule->handle, unownedModule);
726     return PyModuleRef(unownedModule, std::move(pyRef));
727   }
728   // Use existing.
729   PyModule *existing = it->second.second;
730   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
731   return PyModuleRef(existing, std::move(pyRef));
732 }
733 
734 py::object PyModule::createFromCapsule(py::object capsule) {
735   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
736   if (mlirModuleIsNull(rawModule))
737     throw py::error_already_set();
738   return forModule(rawModule).releaseObject();
739 }
740 
741 py::object PyModule::getCapsule() {
742   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
743 }
744 
745 //------------------------------------------------------------------------------
746 // PyOperation
747 //------------------------------------------------------------------------------
748 
749 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
750     : BaseContextObject(std::move(contextRef)), operation(operation) {}
751 
752 PyOperation::~PyOperation() {
753   // If the operation has already been invalidated there is nothing to do.
754   if (!valid)
755     return;
756   auto &liveOperations = getContext()->liveOperations;
757   assert(liveOperations.count(operation.ptr) == 1 &&
758          "destroying operation not in live map");
759   liveOperations.erase(operation.ptr);
760   if (!isAttached()) {
761     mlirOperationDestroy(operation);
762   }
763 }
764 
765 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
766                                            MlirOperation operation,
767                                            py::object parentKeepAlive) {
768   auto &liveOperations = contextRef->liveOperations;
769   // Create.
770   PyOperation *unownedOperation =
771       new PyOperation(std::move(contextRef), operation);
772   // Note that the default return value policy on cast is automatic_reference,
773   // which does not take ownership (delete will not be called).
774   // Just be explicit.
775   py::object pyRef =
776       py::cast(unownedOperation, py::return_value_policy::take_ownership);
777   unownedOperation->handle = pyRef;
778   if (parentKeepAlive) {
779     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
780   }
781   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
782   return PyOperationRef(unownedOperation, std::move(pyRef));
783 }
784 
785 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
786                                          MlirOperation operation,
787                                          py::object parentKeepAlive) {
788   auto &liveOperations = contextRef->liveOperations;
789   auto it = liveOperations.find(operation.ptr);
790   if (it == liveOperations.end()) {
791     // Create.
792     return createInstance(std::move(contextRef), operation,
793                           std::move(parentKeepAlive));
794   }
795   // Use existing.
796   PyOperation *existing = it->second.second;
797   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
798   return PyOperationRef(existing, std::move(pyRef));
799 }
800 
801 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
802                                            MlirOperation operation,
803                                            py::object parentKeepAlive) {
804   auto &liveOperations = contextRef->liveOperations;
805   assert(liveOperations.count(operation.ptr) == 0 &&
806          "cannot create detached operation that already exists");
807   (void)liveOperations;
808 
809   PyOperationRef created = createInstance(std::move(contextRef), operation,
810                                           std::move(parentKeepAlive));
811   created->attached = false;
812   return created;
813 }
814 
815 void PyOperation::checkValid() const {
816   if (!valid) {
817     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
818   }
819 }
820 
821 void PyOperationBase::print(py::object fileObject, bool binary,
822                             llvm::Optional<int64_t> largeElementsLimit,
823                             bool enableDebugInfo, bool prettyDebugInfo,
824                             bool printGenericOpForm, bool useLocalScope) {
825   PyOperation &operation = getOperation();
826   operation.checkValid();
827   if (fileObject.is_none())
828     fileObject = py::module::import("sys").attr("stdout");
829 
830   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
831     fileObject.attr("write")("// Verification failed, printing generic form\n");
832     printGenericOpForm = true;
833   }
834 
835   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
836   if (largeElementsLimit)
837     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
838   if (enableDebugInfo)
839     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
840   if (printGenericOpForm)
841     mlirOpPrintingFlagsPrintGenericOpForm(flags);
842 
843   PyFileAccumulator accum(fileObject, binary);
844   py::gil_scoped_release();
845   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
846                               accum.getUserData());
847   mlirOpPrintingFlagsDestroy(flags);
848 }
849 
850 py::object PyOperationBase::getAsm(bool binary,
851                                    llvm::Optional<int64_t> largeElementsLimit,
852                                    bool enableDebugInfo, bool prettyDebugInfo,
853                                    bool printGenericOpForm,
854                                    bool useLocalScope) {
855   py::object fileObject;
856   if (binary) {
857     fileObject = py::module::import("io").attr("BytesIO")();
858   } else {
859     fileObject = py::module::import("io").attr("StringIO")();
860   }
861   print(fileObject, /*binary=*/binary,
862         /*largeElementsLimit=*/largeElementsLimit,
863         /*enableDebugInfo=*/enableDebugInfo,
864         /*prettyDebugInfo=*/prettyDebugInfo,
865         /*printGenericOpForm=*/printGenericOpForm,
866         /*useLocalScope=*/useLocalScope);
867 
868   return fileObject.attr("getvalue")();
869 }
870 
871 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
872   checkValid();
873   if (!isAttached())
874     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
875   MlirOperation operation = mlirOperationGetParentOperation(get());
876   if (mlirOperationIsNull(operation))
877     return {};
878   return PyOperation::forOperation(getContext(), operation);
879 }
880 
881 PyBlock PyOperation::getBlock() {
882   checkValid();
883   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
884   MlirBlock block = mlirOperationGetBlock(get());
885   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
886   assert(parentOperation && "Operation has no parent");
887   return PyBlock{std::move(*parentOperation), block};
888 }
889 
890 py::object PyOperation::getCapsule() {
891   checkValid();
892   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
893 }
894 
895 py::object PyOperation::createFromCapsule(py::object capsule) {
896   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
897   if (mlirOperationIsNull(rawOperation))
898     throw py::error_already_set();
899   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
900   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
901       .releaseObject();
902 }
903 
/// Creates a new operation from its components. The result is wrapped via
/// createOpView(), so it is a registered OpView subclass when one exists for
/// `name` and a generic OpView otherwise.
///
/// All Python-side arguments are unpacked and validated before any MLIR
/// operation state is built. Once mlirOperationStateGet has been called, a
/// thrown exception would leak that state.
///
/// \param name       Operation name (e.g. "dialect.operation").
/// \param results    Optional result types; a None entry raises ValueError.
/// \param operands   Optional operand values; a None entry raises ValueError.
/// \param attributes Optional dict of string keys to Attribute values. A
///                   non-string key or a non-Attribute value raises
///                   py::cast_error with a message naming the operation.
/// \param successors Optional successor blocks; a None entry raises
///                   ValueError.
/// \param regions    Number of freshly created regions to attach; must be
///                   >= 0.
/// \param location   Location of the new op; also supplies its context.
/// \param maybeIp    `False` disables insertion. None uses the current
///                   default insertion point (if any). Any other value must
///                   be an InsertionPoint.
py::object PyOperation::create(
    std::string name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, py::object maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Attribute names are owned here as std::string so that the string refs
  // built from them further down stay valid until the op is created.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        // It must be caught before the more general py::cast_error below.
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Create `regions` new regions and hand them to the state.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  // It starts out detached; it becomes attached only if inserted below.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);

  // InsertPoint active?
  // `False` explicitly disables insertion. None falls back to the thread's
  // default insertion point, which may itself be null (no insertion).
  if (!maybeIp.is(py::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
      ip = PyThreadContextEntry::getDefaultInsertionPoint();
    } else {
      ip = py::cast<PyInsertionPoint *>(maybeIp);
    }
    if (ip)
      ip->insert(*created.get());
  }

  return created->createOpView();
}
1036 
1037 py::object PyOperation::createOpView() {
1038   checkValid();
1039   MlirIdentifier ident = mlirOperationGetName(get());
1040   MlirStringRef identStr = mlirIdentifierStr(ident);
1041   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1042       StringRef(identStr.data, identStr.length));
1043   if (opViewClass)
1044     return (*opViewClass)(getRef().getObject());
1045   return py::cast(PyOpView(getRef().getObject()));
1046 }
1047 
1048 void PyOperation::erase() {
1049   checkValid();
1050   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1051   // Python reference to a child operation is live. All children should also
1052   // have their `valid` bit set to false.
1053   auto &liveOperations = getContext()->liveOperations;
1054   if (liveOperations.count(operation.ptr))
1055     liveOperations.erase(operation.ptr);
1056   mlirOperationDestroy(operation);
1057   valid = false;
1058 }
1059 
1060 //------------------------------------------------------------------------------
1061 // PyOpView
1062 //------------------------------------------------------------------------------
1063 
1064 py::object
1065 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1066                        py::list operandList,
1067                        llvm::Optional<py::dict> attributes,
1068                        llvm::Optional<std::vector<PyBlock *>> successors,
1069                        llvm::Optional<int> regions,
1070                        DefaultingPyLocation location, py::object maybeIp) {
1071   PyMlirContextRef context = location->getContext();
1072   // Class level operation construction metadata.
1073   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1074   // Operand and result segment specs are either none, which does no
1075   // variadic unpacking, or a list of ints with segment sizes, where each
1076   // element is either a positive number (typically 1 for a scalar) or -1 to
1077   // indicate that it is derived from the length of the same-indexed operand
1078   // or result (implying that it is a list at that position).
1079   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1080   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1081 
1082   std::vector<uint32_t> operandSegmentLengths;
1083   std::vector<uint32_t> resultSegmentLengths;
1084 
1085   // Validate/determine region count.
1086   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1087   int opMinRegionCount = std::get<0>(opRegionSpec);
1088   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1089   if (!regions) {
1090     regions = opMinRegionCount;
1091   }
1092   if (*regions < opMinRegionCount) {
1093     throw py::value_error(
1094         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1095          llvm::Twine(opMinRegionCount) +
1096          " regions but was built with regions=" + llvm::Twine(*regions))
1097             .str());
1098   }
1099   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1100     throw py::value_error(
1101         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1102          llvm::Twine(opMinRegionCount) +
1103          " regions but was built with regions=" + llvm::Twine(*regions))
1104             .str());
1105   }
1106 
1107   // Unpack results.
1108   std::vector<PyType *> resultTypes;
1109   resultTypes.reserve(resultTypeList.size());
1110   if (resultSegmentSpecObj.is_none()) {
1111     // Non-variadic result unpacking.
1112     for (auto it : llvm::enumerate(resultTypeList)) {
1113       try {
1114         resultTypes.push_back(py::cast<PyType *>(it.value()));
1115         if (!resultTypes.back())
1116           throw py::cast_error();
1117       } catch (py::cast_error &err) {
1118         throw py::value_error((llvm::Twine("Result ") +
1119                                llvm::Twine(it.index()) + " of operation \"" +
1120                                name + "\" must be a Type (" + err.what() + ")")
1121                                   .str());
1122       }
1123     }
1124   } else {
1125     // Sized result unpacking.
1126     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1127     if (resultSegmentSpec.size() != resultTypeList.size()) {
1128       throw py::value_error((llvm::Twine("Operation \"") + name +
1129                              "\" requires " +
1130                              llvm::Twine(resultSegmentSpec.size()) +
1131                              "result segments but was provided " +
1132                              llvm::Twine(resultTypeList.size()))
1133                                 .str());
1134     }
1135     resultSegmentLengths.reserve(resultTypeList.size());
1136     for (auto it :
1137          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1138       int segmentSpec = std::get<1>(it.value());
1139       if (segmentSpec == 1 || segmentSpec == 0) {
1140         // Unpack unary element.
1141         try {
1142           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1143           if (resultType) {
1144             resultTypes.push_back(resultType);
1145             resultSegmentLengths.push_back(1);
1146           } else if (segmentSpec == 0) {
1147             // Allowed to be optional.
1148             resultSegmentLengths.push_back(0);
1149           } else {
1150             throw py::cast_error("was None and result is not optional");
1151           }
1152         } catch (py::cast_error &err) {
1153           throw py::value_error((llvm::Twine("Result ") +
1154                                  llvm::Twine(it.index()) + " of operation \"" +
1155                                  name + "\" must be a Type (" + err.what() +
1156                                  ")")
1157                                     .str());
1158         }
1159       } else if (segmentSpec == -1) {
1160         // Unpack sequence by appending.
1161         try {
1162           if (std::get<0>(it.value()).is_none()) {
1163             // Treat it as an empty list.
1164             resultSegmentLengths.push_back(0);
1165           } else {
1166             // Unpack the list.
1167             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1168             for (py::object segmentItem : segment) {
1169               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1170               if (!resultTypes.back()) {
1171                 throw py::cast_error("contained a None item");
1172               }
1173             }
1174             resultSegmentLengths.push_back(segment.size());
1175           }
1176         } catch (std::exception &err) {
1177           // NOTE: Sloppy to be using a catch-all here, but there are at least
1178           // three different unrelated exceptions that can be thrown in the
1179           // above "casts". Just keep the scope above small and catch them all.
1180           throw py::value_error((llvm::Twine("Result ") +
1181                                  llvm::Twine(it.index()) + " of operation \"" +
1182                                  name + "\" must be a Sequence of Types (" +
1183                                  err.what() + ")")
1184                                     .str());
1185         }
1186       } else {
1187         throw py::value_error("Unexpected segment spec");
1188       }
1189     }
1190   }
1191 
1192   // Unpack operands.
1193   std::vector<PyValue *> operands;
1194   operands.reserve(operands.size());
1195   if (operandSegmentSpecObj.is_none()) {
1196     // Non-sized operand unpacking.
1197     for (auto it : llvm::enumerate(operandList)) {
1198       try {
1199         operands.push_back(py::cast<PyValue *>(it.value()));
1200         if (!operands.back())
1201           throw py::cast_error();
1202       } catch (py::cast_error &err) {
1203         throw py::value_error((llvm::Twine("Operand ") +
1204                                llvm::Twine(it.index()) + " of operation \"" +
1205                                name + "\" must be a Value (" + err.what() + ")")
1206                                   .str());
1207       }
1208     }
1209   } else {
1210     // Sized operand unpacking.
1211     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1212     if (operandSegmentSpec.size() != operandList.size()) {
1213       throw py::value_error((llvm::Twine("Operation \"") + name +
1214                              "\" requires " +
1215                              llvm::Twine(operandSegmentSpec.size()) +
1216                              "operand segments but was provided " +
1217                              llvm::Twine(operandList.size()))
1218                                 .str());
1219     }
1220     operandSegmentLengths.reserve(operandList.size());
1221     for (auto it :
1222          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1223       int segmentSpec = std::get<1>(it.value());
1224       if (segmentSpec == 1 || segmentSpec == 0) {
1225         // Unpack unary element.
1226         try {
1227           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1228           if (operandValue) {
1229             operands.push_back(operandValue);
1230             operandSegmentLengths.push_back(1);
1231           } else if (segmentSpec == 0) {
1232             // Allowed to be optional.
1233             operandSegmentLengths.push_back(0);
1234           } else {
1235             throw py::cast_error("was None and operand is not optional");
1236           }
1237         } catch (py::cast_error &err) {
1238           throw py::value_error((llvm::Twine("Operand ") +
1239                                  llvm::Twine(it.index()) + " of operation \"" +
1240                                  name + "\" must be a Value (" + err.what() +
1241                                  ")")
1242                                     .str());
1243         }
1244       } else if (segmentSpec == -1) {
1245         // Unpack sequence by appending.
1246         try {
1247           if (std::get<0>(it.value()).is_none()) {
1248             // Treat it as an empty list.
1249             operandSegmentLengths.push_back(0);
1250           } else {
1251             // Unpack the list.
1252             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1253             for (py::object segmentItem : segment) {
1254               operands.push_back(py::cast<PyValue *>(segmentItem));
1255               if (!operands.back()) {
1256                 throw py::cast_error("contained a None item");
1257               }
1258             }
1259             operandSegmentLengths.push_back(segment.size());
1260           }
1261         } catch (std::exception &err) {
1262           // NOTE: Sloppy to be using a catch-all here, but there are at least
1263           // three different unrelated exceptions that can be thrown in the
1264           // above "casts". Just keep the scope above small and catch them all.
1265           throw py::value_error((llvm::Twine("Operand ") +
1266                                  llvm::Twine(it.index()) + " of operation \"" +
1267                                  name + "\" must be a Sequence of Values (" +
1268                                  err.what() + ")")
1269                                     .str());
1270         }
1271       } else {
1272         throw py::value_error("Unexpected segment spec");
1273       }
1274     }
1275   }
1276 
1277   // Merge operand/result segment lengths into attributes if needed.
1278   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1279     // Dup.
1280     if (attributes) {
1281       attributes = py::dict(*attributes);
1282     } else {
1283       attributes = py::dict();
1284     }
1285     if (attributes->contains("result_segment_sizes") ||
1286         attributes->contains("operand_segment_sizes")) {
1287       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1288                             "'operand_segment_sizes' attribute is unsupported. "
1289                             "Use Operation.create for such low-level access.");
1290     }
1291 
1292     // Add result_segment_sizes attribute.
1293     if (!resultSegmentLengths.empty()) {
1294       int64_t size = resultSegmentLengths.size();
1295       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1296           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1297           resultSegmentLengths.size(), resultSegmentLengths.data());
1298       (*attributes)["result_segment_sizes"] =
1299           PyAttribute(context, segmentLengthAttr);
1300     }
1301 
1302     // Add operand_segment_sizes attribute.
1303     if (!operandSegmentLengths.empty()) {
1304       int64_t size = operandSegmentLengths.size();
1305       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1306           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1307           operandSegmentLengths.size(), operandSegmentLengths.data());
1308       (*attributes)["operand_segment_sizes"] =
1309           PyAttribute(context, segmentLengthAttr);
1310     }
1311   }
1312 
1313   // Delegate to create.
1314   return PyOperation::create(std::move(name),
1315                              /*results=*/std::move(resultTypes),
1316                              /*operands=*/std::move(operands),
1317                              /*attributes=*/std::move(attributes),
1318                              /*successors=*/std::move(successors),
1319                              /*regions=*/*regions, location, maybeIp);
1320 }
1321 
/// Constructs an OpView over an existing operation. Any PyOperationBase (an
/// Operation or another OpView) is accepted. The view stores the underlying
/// PyOperation and that operation's own Python object, which is not
/// necessarily the `operationObject` that was passed in.
///
/// NOTE(review): the second initializer reads the `operation` member, so this
/// relies on `operation` being declared before `operationObject` in the class
/// definition. Confirm in the header.
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1327 
1328 py::object PyOpView::createRawSubclass(py::object userClass) {
1329   // This is... a little gross. The typical pattern is to have a pure python
1330   // class that extends OpView like:
1331   //   class AddFOp(_cext.ir.OpView):
1332   //     def __init__(self, loc, lhs, rhs):
1333   //       operation = loc.context.create_operation(
1334   //           "addf", lhs, rhs, results=[lhs.type])
1335   //       super().__init__(operation)
1336   //
1337   // I.e. The goal of the user facing type is to provide a nice constructor
1338   // that has complete freedom for the op under construction. This is at odds
1339   // with our other desire to sometimes create this object by just passing an
1340   // operation (to initialize the base class). We could do *arg and **kwargs
1341   // munging to try to make it work, but instead, we synthesize a new class
1342   // on the fly which extends this user class (AddFOp in this example) and
1343   // *give it* the base class's __init__ method, thus bypassing the
1344   // intermediate subclass's __init__ method entirely. While slightly,
1345   // underhanded, this is safe/legal because the type hierarchy has not changed
1346   // (we just added a new leaf) and we aren't mucking around with __new__.
1347   // Typically, this new class will be stored on the original as "_Raw" and will
1348   // be used for casts and other things that need a variant of the class that
1349   // is initialized purely from an operation.
1350   py::object parentMetaclass =
1351       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1352   py::dict attributes;
1353   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1354   // now.
1355   //   auto opViewType = py::type::of<PyOpView>();
1356   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1357   attributes["__init__"] = opViewType.attr("__init__");
1358   py::str origName = userClass.attr("__name__");
1359   py::str newName = py::str("_") + origName;
1360   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1361 }
1362 
1363 //------------------------------------------------------------------------------
1364 // PyInsertionPoint.
1365 //------------------------------------------------------------------------------
1366 
/// Insertion point at the end of `block`. No reference operation is set, so
/// insert() appends before "null", i.e. at the end.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}

/// Insertion point immediately before `beforeOperationBase`, inside the block
/// that contains it. getBlock() raises if that operation is invalid or
/// detached.
/// NOTE(review): `block` is initialized from `refOperation`, so this relies on
/// `refOperation` being declared before `block` in the class. Confirm in the
/// header.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1372 
1373 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1374   PyOperation &operation = operationBase.getOperation();
1375   if (operation.isAttached())
1376     throw SetPyError(PyExc_ValueError,
1377                      "Attempt to insert operation that is already attached");
1378   block.getParentOperation()->checkValid();
1379   MlirOperation beforeOp = {nullptr};
1380   if (refOperation) {
1381     // Insert before operation.
1382     (*refOperation)->checkValid();
1383     beforeOp = (*refOperation)->get();
1384   } else {
1385     // Insert at end (before null) is only valid if the block does not
1386     // already end in a known terminator (violating this will cause assertion
1387     // failures later).
1388     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1389       throw py::index_error("Cannot insert operation at the end of a block "
1390                             "that already has a terminator. Did you mean to "
1391                             "use 'InsertionPoint.at_block_terminator(block)' "
1392                             "versus 'InsertionPoint(block)'?");
1393     }
1394   }
1395   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1396   operation.setAttached();
1397 }
1398 
1399 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1400   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1401   if (mlirOperationIsNull(firstOp)) {
1402     // Just insert at end.
1403     return PyInsertionPoint(block);
1404   }
1405 
1406   // Insert before first op.
1407   PyOperationRef firstOpRef = PyOperation::forOperation(
1408       block.getParentOperation()->getContext(), firstOp);
1409   return PyInsertionPoint{block, std::move(firstOpRef)};
1410 }
1411 
1412 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1413   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1414   if (mlirOperationIsNull(terminator))
1415     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1416   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1417       block.getParentOperation()->getContext(), terminator);
1418   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1419 }
1420 
1421 py::object PyInsertionPoint::contextEnter() {
1422   return PyThreadContextEntry::pushInsertionPoint(*this);
1423 }
1424 
1425 void PyInsertionPoint::contextExit(pybind11::object excType,
1426                                    pybind11::object excVal,
1427                                    pybind11::object excTb) {
1428   PyThreadContextEntry::popInsertionPoint(*this);
1429 }
1430 
1431 //------------------------------------------------------------------------------
1432 // PyAttribute.
1433 //------------------------------------------------------------------------------
1434 
1435 bool PyAttribute::operator==(const PyAttribute &other) {
1436   return mlirAttributeEqual(attr, other.attr);
1437 }
1438 
1439 py::object PyAttribute::getCapsule() {
1440   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1441 }
1442 
1443 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1444   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1445   if (mlirAttributeIsNull(rawAttr))
1446     throw py::error_already_set();
1447   return PyAttribute(
1448       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1449 }
1450 
1451 //------------------------------------------------------------------------------
1452 // PyNamedAttribute.
1453 //------------------------------------------------------------------------------
1454 
/// Builds a named attribute pairing `attr` with the name `ownedName`. The
/// name string is moved into storage owned by this object, and the identifier
/// is created in `attr`'s context from that owned copy.
///
/// Inside the member initializer, `ownedName` in `std::move(ownedName)` refers
/// to the constructor parameter, which shadows the member of the same name.
/// NOTE(review): the string is heap-allocated, presumably so its address stays
/// stable if this object is moved. Confirm whether mlirIdentifierGet copies
/// the bytes.
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
    : ownedName(new std::string(std::move(ownedName))) {
  namedAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(mlirAttributeGetContext(attr),
                        toMlirStringRef(*this->ownedName)),
      attr);
}
1462 
1463 //------------------------------------------------------------------------------
1464 // PyType.
1465 //------------------------------------------------------------------------------
1466 
1467 bool PyType::operator==(const PyType &other) {
1468   return mlirTypeEqual(type, other.type);
1469 }
1470 
1471 py::object PyType::getCapsule() {
1472   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1473 }
1474 
1475 PyType PyType::createFromCapsule(py::object capsule) {
1476   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1477   if (mlirTypeIsNull(rawType))
1478     throw py::error_already_set();
1479   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1480                 rawType);
1481 }
1482 
1483 //------------------------------------------------------------------------------
// PyValue and subclasses.
1485 //------------------------------------------------------------------------------
1486 
1487 pybind11::object PyValue::getCapsule() {
1488   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1489 }
1490 
1491 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1492   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1493   if (mlirValueIsNull(value))
1494     throw py::error_already_set();
1495   MlirOperation owner;
1496   if (mlirValueIsAOpResult(value))
1497     owner = mlirOpResultGetOwner(value);
1498   if (mlirValueIsABlockArgument(value))
1499     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1500   if (mlirOperationIsNull(owner))
1501     throw py::error_already_set();
1502   MlirContext ctx = mlirOperationGetContext(owner);
1503   PyOperationRef ownerRef =
1504       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1505   return PyValue(ownerRef, value);
1506 }
1507 
1508 namespace {
1509 /// CRTP base class for Python MLIR values that subclass Value and should be
1510 /// castable from it. The value hierarchy is one level deep and is not supposed
1511 /// to accommodate other levels unless core MLIR changes.
1512 template <typename DerivedTy>
1513 class PyConcreteValue : public PyValue {
1514 public:
1515   // Derived classes must define statics for:
1516   //   IsAFunctionTy isaFunction
1517   //   const char *pyClassName
1518   // and redefine bindDerived.
1519   using ClassTy = py::class_<DerivedTy, PyValue>;
1520   using IsAFunctionTy = bool (*)(MlirValue);
1521 
1522   PyConcreteValue() = default;
1523   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1524       : PyValue(operationRef, value) {}
1525   PyConcreteValue(PyValue &orig)
1526       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1527 
1528   /// Attempts to cast the original value to the derived type and throws on
1529   /// type mismatches.
1530   static MlirValue castFrom(PyValue &orig) {
1531     if (!DerivedTy::isaFunction(orig.get())) {
1532       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1533       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1534                                              DerivedTy::pyClassName +
1535                                              " (from " + origRepr + ")");
1536     }
1537     return orig.get();
1538   }
1539 
1540   /// Binds the Python module objects to functions of this class.
1541   static void bind(py::module &m) {
1542     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1543     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1544     DerivedTy::bindDerived(cls);
1545   }
1546 
1547   /// Implemented by derived classes to add methods to the Python subclass.
1548   static void bindDerived(ClassTy &m) {}
1549 };
1550 
1551 /// Python wrapper for MlirBlockArgument.
1552 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1553 public:
1554   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1555   static constexpr const char *pyClassName = "BlockArgument";
1556   using PyConcreteValue::PyConcreteValue;
1557 
1558   static void bindDerived(ClassTy &c) {
1559     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1560       return PyBlock(self.getParentOperation(),
1561                      mlirBlockArgumentGetOwner(self.get()));
1562     });
1563     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1564       return mlirBlockArgumentGetArgNumber(self.get());
1565     });
1566     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1567       return mlirBlockArgumentSetType(self.get(), type);
1568     });
1569   }
1570 };
1571 
1572 /// Python wrapper for MlirOpResult.
1573 class PyOpResult : public PyConcreteValue<PyOpResult> {
1574 public:
1575   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1576   static constexpr const char *pyClassName = "OpResult";
1577   using PyConcreteValue::PyConcreteValue;
1578 
1579   static void bindDerived(ClassTy &c) {
1580     c.def_property_readonly("owner", [](PyOpResult &self) {
1581       assert(
1582           mlirOperationEqual(self.getParentOperation()->get(),
1583                              mlirOpResultGetOwner(self.get())) &&
1584           "expected the owner of the value in Python to match that in the IR");
1585       return self.getParentOperation().getObject();
1586     });
1587     c.def_property_readonly("result_number", [](PyOpResult &self) {
1588       return mlirOpResultGetResultNumber(self.get());
1589     });
1590   }
1591 };
1592 
1593 /// A list of block arguments. Internally, these are stored as consecutive
1594 /// elements, random access is cheap. The argument list is associated with the
1595 /// operation that contains the block (detached blocks are not allowed in
1596 /// Python bindings) and extends its lifetime.
1597 class PyBlockArgumentList
1598     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1599 public:
1600   static constexpr const char *pyClassName = "BlockArgumentList";
1601 
1602   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1603                       intptr_t startIndex = 0, intptr_t length = -1,
1604                       intptr_t step = 1)
1605       : Sliceable(startIndex,
1606                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1607                   step),
1608         operation(std::move(operation)), block(block) {}
1609 
1610   /// Returns the number of arguments in the list.
1611   intptr_t getNumElements() {
1612     operation->checkValid();
1613     return mlirBlockGetNumArguments(block);
1614   }
1615 
1616   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1617   PyBlockArgument getElement(intptr_t pos) {
1618     MlirValue argument = mlirBlockGetArgument(block, pos);
1619     return PyBlockArgument(operation, argument);
1620   }
1621 
1622   /// Returns a sublist of this list.
1623   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1624                             intptr_t step) {
1625     return PyBlockArgumentList(operation, block, startIndex, length, step);
1626   }
1627 
1628 private:
1629   PyOperationRef operation;
1630   MlirBlock block;
1631 };
1632 
1633 /// A list of operation operands. Internally, these are stored as consecutive
1634 /// elements, random access is cheap. The result list is associated with the
1635 /// operation whose results these are, and extends the lifetime of this
1636 /// operation.
1637 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1638 public:
1639   static constexpr const char *pyClassName = "OpOperandList";
1640 
1641   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1642                   intptr_t length = -1, intptr_t step = 1)
1643       : Sliceable(startIndex,
1644                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1645                                : length,
1646                   step),
1647         operation(operation) {}
1648 
1649   intptr_t getNumElements() {
1650     operation->checkValid();
1651     return mlirOperationGetNumOperands(operation->get());
1652   }
1653 
1654   PyValue getElement(intptr_t pos) {
1655     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1656     MlirOperation owner;
1657     if (mlirValueIsAOpResult(operand))
1658       owner = mlirOpResultGetOwner(operand);
1659     else if (mlirValueIsABlockArgument(operand))
1660       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1661     else
1662       assert(false && "Value must be an block arg or op result.");
1663     PyOperationRef pyOwner =
1664         PyOperation::forOperation(operation->getContext(), owner);
1665     return PyValue(pyOwner, operand);
1666   }
1667 
1668   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1669     return PyOpOperandList(operation, startIndex, length, step);
1670   }
1671 
1672   void dunderSetItem(intptr_t index, PyValue value) {
1673     index = wrapIndex(index);
1674     mlirOperationSetOperand(operation->get(), index, value.get());
1675   }
1676 
1677   static void bindDerived(ClassTy &c) {
1678     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1679   }
1680 
1681 private:
1682   PyOperationRef operation;
1683 };
1684 
1685 /// A list of operation results. Internally, these are stored as consecutive
1686 /// elements, random access is cheap. The result list is associated with the
1687 /// operation whose results these are, and extends the lifetime of this
1688 /// operation.
1689 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1690 public:
1691   static constexpr const char *pyClassName = "OpResultList";
1692 
1693   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1694                  intptr_t length = -1, intptr_t step = 1)
1695       : Sliceable(startIndex,
1696                   length == -1 ? mlirOperationGetNumResults(operation->get())
1697                                : length,
1698                   step),
1699         operation(operation) {}
1700 
1701   intptr_t getNumElements() {
1702     operation->checkValid();
1703     return mlirOperationGetNumResults(operation->get());
1704   }
1705 
1706   PyOpResult getElement(intptr_t index) {
1707     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1708     return PyOpResult(value);
1709   }
1710 
1711   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1712     return PyOpResultList(operation, startIndex, length, step);
1713   }
1714 
1715 private:
1716   PyOperationRef operation;
1717 };
1718 
1719 /// A list of operation attributes. Can be indexed by name, producing
1720 /// attributes, or by index, producing named attributes.
1721 class PyOpAttributeMap {
1722 public:
1723   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1724 
1725   PyAttribute dunderGetItemNamed(const std::string &name) {
1726     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1727                                                          toMlirStringRef(name));
1728     if (mlirAttributeIsNull(attr)) {
1729       throw SetPyError(PyExc_KeyError,
1730                        "attempt to access a non-existent attribute");
1731     }
1732     return PyAttribute(operation->getContext(), attr);
1733   }
1734 
1735   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1736     if (index < 0 || index >= dunderLen()) {
1737       throw SetPyError(PyExc_IndexError,
1738                        "attempt to access out of bounds attribute");
1739     }
1740     MlirNamedAttribute namedAttr =
1741         mlirOperationGetAttribute(operation->get(), index);
1742     return PyNamedAttribute(
1743         namedAttr.attribute,
1744         std::string(mlirIdentifierStr(namedAttr.name).data));
1745   }
1746 
1747   void dunderSetItem(const std::string &name, PyAttribute attr) {
1748     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1749                                     attr);
1750   }
1751 
1752   void dunderDelItem(const std::string &name) {
1753     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1754                                                      toMlirStringRef(name));
1755     if (!removed)
1756       throw SetPyError(PyExc_KeyError,
1757                        "attempt to delete a non-existent attribute");
1758   }
1759 
1760   intptr_t dunderLen() {
1761     return mlirOperationGetNumAttributes(operation->get());
1762   }
1763 
1764   bool dunderContains(const std::string &name) {
1765     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1766         operation->get(), toMlirStringRef(name)));
1767   }
1768 
1769   static void bind(py::module &m) {
1770     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
1771         .def("__contains__", &PyOpAttributeMap::dunderContains)
1772         .def("__len__", &PyOpAttributeMap::dunderLen)
1773         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1774         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1775         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1776         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1777   }
1778 
1779 private:
1780   PyOperationRef operation;
1781 };
1782 
1783 } // end namespace
1784 
1785 //------------------------------------------------------------------------------
1786 // Populates the core exports of the 'ir' submodule.
1787 //------------------------------------------------------------------------------
1788 
1789 void mlir::python::populateIRCore(py::module &m) {
1790   //----------------------------------------------------------------------------
1791   // Mapping of MlirContext.
1792   //----------------------------------------------------------------------------
1793   py::class_<PyMlirContext>(m, "Context", py::module_local())
1794       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1795       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1796       .def("_get_context_again",
1797            [](PyMlirContext &self) {
1798              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1799              return ref.releaseObject();
1800            })
1801       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1802       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1803       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1804                              &PyMlirContext::getCapsule)
1805       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1806       .def("__enter__", &PyMlirContext::contextEnter)
1807       .def("__exit__", &PyMlirContext::contextExit)
1808       .def_property_readonly_static(
1809           "current",
1810           [](py::object & /*class*/) {
1811             auto *context = PyThreadContextEntry::getDefaultContext();
1812             if (!context)
1813               throw SetPyError(PyExc_ValueError, "No current Context");
1814             return context;
1815           },
1816           "Gets the Context bound to the current thread or raises ValueError")
1817       .def_property_readonly(
1818           "dialects",
1819           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1820           "Gets a container for accessing dialects by name")
1821       .def_property_readonly(
1822           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1823           "Alias for 'dialect'")
1824       .def(
1825           "get_dialect_descriptor",
1826           [=](PyMlirContext &self, std::string &name) {
1827             MlirDialect dialect = mlirContextGetOrLoadDialect(
1828                 self.get(), {name.data(), name.size()});
1829             if (mlirDialectIsNull(dialect)) {
1830               throw SetPyError(PyExc_ValueError,
1831                                Twine("Dialect '") + name + "' not found");
1832             }
1833             return PyDialectDescriptor(self.getRef(), dialect);
1834           },
1835           "Gets or loads a dialect by name, returning its descriptor object")
1836       .def_property(
1837           "allow_unregistered_dialects",
1838           [](PyMlirContext &self) -> bool {
1839             return mlirContextGetAllowUnregisteredDialects(self.get());
1840           },
1841           [](PyMlirContext &self, bool value) {
1842             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1843           })
1844       .def("enable_multithreading",
1845            [](PyMlirContext &self, bool enable) {
1846              mlirContextEnableMultithreading(self.get(), enable);
1847            })
1848       .def("is_registered_operation",
1849            [](PyMlirContext &self, std::string &name) {
1850              return mlirContextIsRegisteredOperation(
1851                  self.get(), MlirStringRef{name.data(), name.size()});
1852            });
1853 
1854   //----------------------------------------------------------------------------
1855   // Mapping of PyDialectDescriptor
1856   //----------------------------------------------------------------------------
1857   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
1858       .def_property_readonly("namespace",
1859                              [](PyDialectDescriptor &self) {
1860                                MlirStringRef ns =
1861                                    mlirDialectGetNamespace(self.get());
1862                                return py::str(ns.data, ns.length);
1863                              })
1864       .def("__repr__", [](PyDialectDescriptor &self) {
1865         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1866         std::string repr("<DialectDescriptor ");
1867         repr.append(ns.data, ns.length);
1868         repr.append(">");
1869         return repr;
1870       });
1871 
1872   //----------------------------------------------------------------------------
1873   // Mapping of PyDialects
1874   //----------------------------------------------------------------------------
1875   py::class_<PyDialects>(m, "Dialects", py::module_local())
1876       .def("__getitem__",
1877            [=](PyDialects &self, std::string keyName) {
1878              MlirDialect dialect =
1879                  self.getDialectForKey(keyName, /*attrError=*/false);
1880              py::object descriptor =
1881                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1882              return createCustomDialectWrapper(keyName, std::move(descriptor));
1883            })
1884       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1885         MlirDialect dialect =
1886             self.getDialectForKey(attrName, /*attrError=*/true);
1887         py::object descriptor =
1888             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1889         return createCustomDialectWrapper(attrName, std::move(descriptor));
1890       });
1891 
1892   //----------------------------------------------------------------------------
1893   // Mapping of PyDialect
1894   //----------------------------------------------------------------------------
1895   py::class_<PyDialect>(m, "Dialect", py::module_local())
1896       .def(py::init<py::object>(), "descriptor")
1897       .def_property_readonly(
1898           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1899       .def("__repr__", [](py::object self) {
1900         auto clazz = self.attr("__class__");
1901         return py::str("<Dialect ") +
1902                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1903                clazz.attr("__module__") + py::str(".") +
1904                clazz.attr("__name__") + py::str(")>");
1905       });
1906 
1907   //----------------------------------------------------------------------------
1908   // Mapping of Location
1909   //----------------------------------------------------------------------------
1910   py::class_<PyLocation>(m, "Location", py::module_local())
1911       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
1912       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
1913       .def("__enter__", &PyLocation::contextEnter)
1914       .def("__exit__", &PyLocation::contextExit)
1915       .def("__eq__",
1916            [](PyLocation &self, PyLocation &other) -> bool {
1917              return mlirLocationEqual(self, other);
1918            })
1919       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
1920       .def_property_readonly_static(
1921           "current",
1922           [](py::object & /*class*/) {
1923             auto *loc = PyThreadContextEntry::getDefaultLocation();
1924             if (!loc)
1925               throw SetPyError(PyExc_ValueError, "No current Location");
1926             return loc;
1927           },
1928           "Gets the Location bound to the current thread or raises ValueError")
1929       .def_static(
1930           "unknown",
1931           [](DefaultingPyMlirContext context) {
1932             return PyLocation(context->getRef(),
1933                               mlirLocationUnknownGet(context->get()));
1934           },
1935           py::arg("context") = py::none(),
1936           "Gets a Location representing an unknown location")
1937       .def_static(
1938           "file",
1939           [](std::string filename, int line, int col,
1940              DefaultingPyMlirContext context) {
1941             return PyLocation(
1942                 context->getRef(),
1943                 mlirLocationFileLineColGet(
1944                     context->get(), toMlirStringRef(filename), line, col));
1945           },
1946           py::arg("filename"), py::arg("line"), py::arg("col"),
1947           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
1948       .def_property_readonly(
1949           "context",
1950           [](PyLocation &self) { return self.getContext().getObject(); },
1951           "Context that owns the Location")
1952       .def("__repr__", [](PyLocation &self) {
1953         PyPrintAccumulator printAccum;
1954         mlirLocationPrint(self, printAccum.getCallback(),
1955                           printAccum.getUserData());
1956         return printAccum.join();
1957       });
1958 
1959   //----------------------------------------------------------------------------
1960   // Mapping of Module
1961   //----------------------------------------------------------------------------
1962   py::class_<PyModule>(m, "Module", py::module_local())
1963       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
1964       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
1965       .def_static(
1966           "parse",
1967           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
1968             MlirModule module = mlirModuleCreateParse(
1969                 context->get(), toMlirStringRef(moduleAsm));
1970             // TODO: Rework error reporting once diagnostic engine is exposed
1971             // in C API.
1972             if (mlirModuleIsNull(module)) {
1973               throw SetPyError(
1974                   PyExc_ValueError,
1975                   "Unable to parse module assembly (see diagnostics)");
1976             }
1977             return PyModule::forModule(module).releaseObject();
1978           },
1979           py::arg("asm"), py::arg("context") = py::none(),
1980           kModuleParseDocstring)
1981       .def_static(
1982           "create",
1983           [](DefaultingPyLocation loc) {
1984             MlirModule module = mlirModuleCreateEmpty(loc);
1985             return PyModule::forModule(module).releaseObject();
1986           },
1987           py::arg("loc") = py::none(), "Creates an empty module")
1988       .def_property_readonly(
1989           "context",
1990           [](PyModule &self) { return self.getContext().getObject(); },
1991           "Context that created the Module")
1992       .def_property_readonly(
1993           "operation",
1994           [](PyModule &self) {
1995             return PyOperation::forOperation(self.getContext(),
1996                                              mlirModuleGetOperation(self.get()),
1997                                              self.getRef().releaseObject())
1998                 .releaseObject();
1999           },
2000           "Accesses the module as an operation")
2001       .def_property_readonly(
2002           "body",
2003           [](PyModule &self) {
2004             PyOperationRef module_op = PyOperation::forOperation(
2005                 self.getContext(), mlirModuleGetOperation(self.get()),
2006                 self.getRef().releaseObject());
2007             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2008             return returnBlock;
2009           },
2010           "Return the block for this module")
2011       .def(
2012           "dump",
2013           [](PyModule &self) {
2014             mlirOperationDump(mlirModuleGetOperation(self.get()));
2015           },
2016           kDumpDocstring)
2017       .def(
2018           "__str__",
2019           [](PyModule &self) {
2020             MlirOperation operation = mlirModuleGetOperation(self.get());
2021             PyPrintAccumulator printAccum;
2022             mlirOperationPrint(operation, printAccum.getCallback(),
2023                                printAccum.getUserData());
2024             return printAccum.join();
2025           },
2026           kOperationStrDunderDocstring);
2027 
2028   //----------------------------------------------------------------------------
2029   // Mapping of Operation.
2030   //----------------------------------------------------------------------------
2031   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2032       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2033                              [](PyOperationBase &self) {
2034                                return self.getOperation().getCapsule();
2035                              })
2036       .def("__eq__",
2037            [](PyOperationBase &self, PyOperationBase &other) {
2038              return &self.getOperation() == &other.getOperation();
2039            })
2040       .def("__eq__",
2041            [](PyOperationBase &self, py::object other) { return false; })
2042       .def_property_readonly("attributes",
2043                              [](PyOperationBase &self) {
2044                                return PyOpAttributeMap(
2045                                    self.getOperation().getRef());
2046                              })
2047       .def_property_readonly("operands",
2048                              [](PyOperationBase &self) {
2049                                return PyOpOperandList(
2050                                    self.getOperation().getRef());
2051                              })
2052       .def_property_readonly("regions",
2053                              [](PyOperationBase &self) {
2054                                return PyRegionList(
2055                                    self.getOperation().getRef());
2056                              })
2057       .def_property_readonly(
2058           "results",
2059           [](PyOperationBase &self) {
2060             return PyOpResultList(self.getOperation().getRef());
2061           },
2062           "Returns the list of Operation results.")
2063       .def_property_readonly(
2064           "result",
2065           [](PyOperationBase &self) {
2066             auto &operation = self.getOperation();
2067             auto numResults = mlirOperationGetNumResults(operation);
2068             if (numResults != 1) {
2069               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2070               throw SetPyError(
2071                   PyExc_ValueError,
2072                   Twine("Cannot call .result on operation ") +
2073                       StringRef(name.data, name.length) + " which has " +
2074                       Twine(numResults) +
2075                       " results (it is only valid for operations with a "
2076                       "single result)");
2077             }
2078             return PyOpResult(operation.getRef(),
2079                               mlirOperationGetResult(operation, 0));
2080           },
2081           "Shortcut to get an op result if it has only one (throws an error "
2082           "otherwise).")
2083       .def("__iter__",
2084            [](PyOperationBase &self) {
2085              return PyRegionIterator(self.getOperation().getRef());
2086            })
2087       .def(
2088           "__str__",
2089           [](PyOperationBase &self) {
2090             return self.getAsm(/*binary=*/false,
2091                                /*largeElementsLimit=*/llvm::None,
2092                                /*enableDebugInfo=*/false,
2093                                /*prettyDebugInfo=*/false,
2094                                /*printGenericOpForm=*/false,
2095                                /*useLocalScope=*/false);
2096           },
2097           "Returns the assembly form of the operation.")
2098       .def("print", &PyOperationBase::print,
2099            // Careful: Lots of arguments must match up with print method.
2100            py::arg("file") = py::none(), py::arg("binary") = false,
2101            py::arg("large_elements_limit") = py::none(),
2102            py::arg("enable_debug_info") = false,
2103            py::arg("pretty_debug_info") = false,
2104            py::arg("print_generic_op_form") = false,
2105            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2106       .def("get_asm", &PyOperationBase::getAsm,
2107            // Careful: Lots of arguments must match up with get_asm method.
2108            py::arg("binary") = false,
2109            py::arg("large_elements_limit") = py::none(),
2110            py::arg("enable_debug_info") = false,
2111            py::arg("pretty_debug_info") = false,
2112            py::arg("print_generic_op_form") = false,
2113            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2114       .def(
2115           "verify",
2116           [](PyOperationBase &self) {
2117             return mlirOperationVerify(self.getOperation());
2118           },
2119           "Verify the operation and return true if it passes, false if it "
2120           "fails.");
2121 
2122   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2123       .def_static("create", &PyOperation::create, py::arg("name"),
2124                   py::arg("results") = py::none(),
2125                   py::arg("operands") = py::none(),
2126                   py::arg("attributes") = py::none(),
2127                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2128                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2129                   kOperationCreateDocstring)
2130       .def_property_readonly("parent",
2131                              [](PyOperation &self) -> py::object {
2132                                auto parent = self.getParentOperation();
2133                                if (parent)
2134                                  return parent->getObject();
2135                                return py::none();
2136                              })
2137       .def("erase", &PyOperation::erase)
2138       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2139                              &PyOperation::getCapsule)
2140       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2141       .def_property_readonly("name",
2142                              [](PyOperation &self) {
2143                                self.checkValid();
2144                                MlirOperation operation = self.get();
2145                                MlirStringRef name = mlirIdentifierStr(
2146                                    mlirOperationGetName(operation));
2147                                return py::str(name.data, name.length);
2148                              })
2149       .def_property_readonly(
2150           "context",
2151           [](PyOperation &self) {
2152             self.checkValid();
2153             return self.getContext().getObject();
2154           },
2155           "Context that owns the Operation")
2156       .def_property_readonly("opview", &PyOperation::createOpView);
2157 
2158   auto opViewClass =
2159       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2160           .def(py::init<py::object>())
2161           .def_property_readonly("operation", &PyOpView::getOperationObject)
2162           .def_property_readonly(
2163               "context",
2164               [](PyOpView &self) {
2165                 return self.getOperation().getContext().getObject();
2166               },
2167               "Context that owns the Operation")
2168           .def("__str__", [](PyOpView &self) {
2169             return py::str(self.getOperationObject());
2170           });
2171   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2172   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2173   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2174   opViewClass.attr("build_generic") = classmethod(
2175       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2176       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2177       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2178       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2179       "Builds a specific, generated OpView based on class level attributes.");
2180 
2181   //----------------------------------------------------------------------------
2182   // Mapping of PyRegion.
2183   //----------------------------------------------------------------------------
2184   py::class_<PyRegion>(m, "Region", py::module_local())
2185       .def_property_readonly(
2186           "blocks",
2187           [](PyRegion &self) {
2188             return PyBlockList(self.getParentOperation(), self.get());
2189           },
2190           "Returns a forward-optimized sequence of blocks.")
2191       .def(
2192           "__iter__",
2193           [](PyRegion &self) {
2194             self.checkValid();
2195             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2196             return PyBlockIterator(self.getParentOperation(), firstBlock);
2197           },
2198           "Iterates over blocks in the region.")
2199       .def("__eq__",
2200            [](PyRegion &self, PyRegion &other) {
2201              return self.get().ptr == other.get().ptr;
2202            })
2203       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2204 
2205   //----------------------------------------------------------------------------
2206   // Mapping of PyBlock.
2207   //----------------------------------------------------------------------------
2208   py::class_<PyBlock>(m, "Block", py::module_local())
2209       .def_property_readonly(
2210           "owner",
2211           [](PyBlock &self) {
2212             return self.getParentOperation()->createOpView();
2213           },
2214           "Returns the owning operation of this block.")
2215       .def_property_readonly(
2216           "region",
2217           [](PyBlock &self) {
2218             MlirRegion region = mlirBlockGetParentRegion(self.get());
2219             return PyRegion(self.getParentOperation(), region);
2220           },
2221           "Returns the owning region of this block.")
2222       .def_property_readonly(
2223           "arguments",
2224           [](PyBlock &self) {
2225             return PyBlockArgumentList(self.getParentOperation(), self.get());
2226           },
2227           "Returns a list of block arguments.")
2228       .def_property_readonly(
2229           "operations",
2230           [](PyBlock &self) {
2231             return PyOperationList(self.getParentOperation(), self.get());
2232           },
2233           "Returns a forward-optimized sequence of operations.")
2234       .def(
2235           "create_before",
2236           [](PyBlock &self, py::args pyArgTypes) {
2237             self.checkValid();
2238             llvm::SmallVector<MlirType, 4> argTypes;
2239             argTypes.reserve(pyArgTypes.size());
2240             for (auto &pyArg : pyArgTypes) {
2241               argTypes.push_back(pyArg.cast<PyType &>());
2242             }
2243 
2244             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2245             MlirRegion region = mlirBlockGetParentRegion(self.get());
2246             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2247             return PyBlock(self.getParentOperation(), block);
2248           },
2249           "Creates and returns a new Block before this block "
2250           "(with given argument types).")
2251       .def(
2252           "create_after",
2253           [](PyBlock &self, py::args pyArgTypes) {
2254             self.checkValid();
2255             llvm::SmallVector<MlirType, 4> argTypes;
2256             argTypes.reserve(pyArgTypes.size());
2257             for (auto &pyArg : pyArgTypes) {
2258               argTypes.push_back(pyArg.cast<PyType &>());
2259             }
2260 
2261             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2262             MlirRegion region = mlirBlockGetParentRegion(self.get());
2263             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2264             return PyBlock(self.getParentOperation(), block);
2265           },
2266           "Creates and returns a new Block after this block "
2267           "(with given argument types).")
2268       .def(
2269           "__iter__",
2270           [](PyBlock &self) {
2271             self.checkValid();
2272             MlirOperation firstOperation =
2273                 mlirBlockGetFirstOperation(self.get());
2274             return PyOperationIterator(self.getParentOperation(),
2275                                        firstOperation);
2276           },
2277           "Iterates over operations in the block.")
2278       .def("__eq__",
2279            [](PyBlock &self, PyBlock &other) {
2280              return self.get().ptr == other.get().ptr;
2281            })
2282       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2283       .def(
2284           "__str__",
2285           [](PyBlock &self) {
2286             self.checkValid();
2287             PyPrintAccumulator printAccum;
2288             mlirBlockPrint(self.get(), printAccum.getCallback(),
2289                            printAccum.getUserData());
2290             return printAccum.join();
2291           },
2292           "Returns the assembly form of the block.");
2293 
2294   //----------------------------------------------------------------------------
2295   // Mapping of PyInsertionPoint.
2296   //----------------------------------------------------------------------------
2297 
2298   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2299       .def(py::init<PyBlock &>(), py::arg("block"),
2300            "Inserts after the last operation but still inside the block.")
2301       .def("__enter__", &PyInsertionPoint::contextEnter)
2302       .def("__exit__", &PyInsertionPoint::contextExit)
2303       .def_property_readonly_static(
2304           "current",
2305           [](py::object & /*class*/) {
2306             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2307             if (!ip)
2308               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2309             return ip;
2310           },
2311           "Gets the InsertionPoint bound to the current thread or raises "
2312           "ValueError if none has been set")
2313       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2314            "Inserts before a referenced operation.")
2315       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2316                   py::arg("block"), "Inserts at the beginning of the block.")
2317       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2318                   py::arg("block"), "Inserts before the block terminator.")
2319       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2320            "Inserts an operation.")
2321       .def_property_readonly(
2322           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2323           "Returns the block that this InsertionPoint points to.");
2324 
2325   //----------------------------------------------------------------------------
2326   // Mapping of PyAttribute.
2327   //----------------------------------------------------------------------------
2328   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2329       // Delegate to the PyAttribute copy constructor, which will also lifetime
2330       // extend the backing context which owns the MlirAttribute.
2331       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2332            "Casts the passed attribute to the generic Attribute")
2333       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2334                              &PyAttribute::getCapsule)
2335       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2336       .def_static(
2337           "parse",
2338           [](std::string attrSpec, DefaultingPyMlirContext context) {
2339             MlirAttribute type = mlirAttributeParseGet(
2340                 context->get(), toMlirStringRef(attrSpec));
2341             // TODO: Rework error reporting once diagnostic engine is exposed
2342             // in C API.
2343             if (mlirAttributeIsNull(type)) {
2344               throw SetPyError(PyExc_ValueError,
2345                                Twine("Unable to parse attribute: '") +
2346                                    attrSpec + "'");
2347             }
2348             return PyAttribute(context->getRef(), type);
2349           },
2350           py::arg("asm"), py::arg("context") = py::none(),
2351           "Parses an attribute from an assembly form")
2352       .def_property_readonly(
2353           "context",
2354           [](PyAttribute &self) { return self.getContext().getObject(); },
2355           "Context that owns the Attribute")
2356       .def_property_readonly("type",
2357                              [](PyAttribute &self) {
2358                                return PyType(self.getContext()->getRef(),
2359                                              mlirAttributeGetType(self));
2360                              })
2361       .def(
2362           "get_named",
2363           [](PyAttribute &self, std::string name) {
2364             return PyNamedAttribute(self, std::move(name));
2365           },
2366           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2367       .def("__eq__",
2368            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2369       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2370       .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; })
2371       .def(
2372           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2373           kDumpDocstring)
2374       .def(
2375           "__str__",
2376           [](PyAttribute &self) {
2377             PyPrintAccumulator printAccum;
2378             mlirAttributePrint(self, printAccum.getCallback(),
2379                                printAccum.getUserData());
2380             return printAccum.join();
2381           },
2382           "Returns the assembly form of the Attribute.")
2383       .def("__repr__", [](PyAttribute &self) {
2384         // Generally, assembly formats are not printed for __repr__ because
2385         // this can cause exceptionally long debug output and exceptions.
2386         // However, attribute values are generally considered useful and are
2387         // printed. This may need to be re-evaluated if debug dumps end up
2388         // being excessive.
2389         PyPrintAccumulator printAccum;
2390         printAccum.parts.append("Attribute(");
2391         mlirAttributePrint(self, printAccum.getCallback(),
2392                            printAccum.getUserData());
2393         printAccum.parts.append(")");
2394         return printAccum.join();
2395       });
2396 
2397   //----------------------------------------------------------------------------
2398   // Mapping of PyNamedAttribute
2399   //----------------------------------------------------------------------------
2400   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2401       .def("__repr__",
2402            [](PyNamedAttribute &self) {
2403              PyPrintAccumulator printAccum;
2404              printAccum.parts.append("NamedAttribute(");
2405              printAccum.parts.append(
2406                  mlirIdentifierStr(self.namedAttr.name).data);
2407              printAccum.parts.append("=");
2408              mlirAttributePrint(self.namedAttr.attribute,
2409                                 printAccum.getCallback(),
2410                                 printAccum.getUserData());
2411              printAccum.parts.append(")");
2412              return printAccum.join();
2413            })
2414       .def_property_readonly(
2415           "name",
2416           [](PyNamedAttribute &self) {
2417             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2418                            mlirIdentifierStr(self.namedAttr.name).length);
2419           },
2420           "The name of the NamedAttribute binding")
2421       .def_property_readonly(
2422           "attr",
2423           [](PyNamedAttribute &self) {
2424             // TODO: When named attribute is removed/refactored, also remove
2425             // this constructor (it does an inefficient table lookup).
2426             auto contextRef = PyMlirContext::forContext(
2427                 mlirAttributeGetContext(self.namedAttr.attribute));
2428             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2429           },
2430           py::keep_alive<0, 1>(),
2431           "The underlying generic attribute of the NamedAttribute binding");
2432 
2433   //----------------------------------------------------------------------------
2434   // Mapping of PyType.
2435   //----------------------------------------------------------------------------
2436   py::class_<PyType>(m, "Type", py::module_local())
2437       // Delegate to the PyType copy constructor, which will also lifetime
2438       // extend the backing context which owns the MlirType.
2439       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2440            "Casts the passed type to the generic Type")
2441       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2442       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2443       .def_static(
2444           "parse",
2445           [](std::string typeSpec, DefaultingPyMlirContext context) {
2446             MlirType type =
2447                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2448             // TODO: Rework error reporting once diagnostic engine is exposed
2449             // in C API.
2450             if (mlirTypeIsNull(type)) {
2451               throw SetPyError(PyExc_ValueError,
2452                                Twine("Unable to parse type: '") + typeSpec +
2453                                    "'");
2454             }
2455             return PyType(context->getRef(), type);
2456           },
2457           py::arg("asm"), py::arg("context") = py::none(),
2458           kContextParseTypeDocstring)
2459       .def_property_readonly(
2460           "context", [](PyType &self) { return self.getContext().getObject(); },
2461           "Context that owns the Type")
2462       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2463       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2464       .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; })
2465       .def(
2466           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2467       .def(
2468           "__str__",
2469           [](PyType &self) {
2470             PyPrintAccumulator printAccum;
2471             mlirTypePrint(self, printAccum.getCallback(),
2472                           printAccum.getUserData());
2473             return printAccum.join();
2474           },
2475           "Returns the assembly form of the type.")
2476       .def("__repr__", [](PyType &self) {
2477         // Generally, assembly formats are not printed for __repr__ because
2478         // this can cause exceptionally long debug output and exceptions.
2479         // However, types are an exception as they typically have compact
2480         // assembly forms and printing them is useful.
2481         PyPrintAccumulator printAccum;
2482         printAccum.parts.append("Type(");
2483         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2484         printAccum.parts.append(")");
2485         return printAccum.join();
2486       });
2487 
2488   //----------------------------------------------------------------------------
2489   // Mapping of Value.
2490   //----------------------------------------------------------------------------
2491   py::class_<PyValue>(m, "Value", py::module_local())
2492       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2493       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2494       .def_property_readonly(
2495           "context",
2496           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2497           "Context in which the value lives.")
2498       .def(
2499           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2500           kDumpDocstring)
2501       .def_property_readonly(
2502           "owner",
2503           [](PyValue &self) {
2504             assert(mlirOperationEqual(self.getParentOperation()->get(),
2505                                       mlirOpResultGetOwner(self.get())) &&
2506                    "expected the owner of the value in Python to match that in "
2507                    "the IR");
2508             return self.getParentOperation().getObject();
2509           })
2510       .def("__eq__",
2511            [](PyValue &self, PyValue &other) {
2512              return self.get().ptr == other.get().ptr;
2513            })
2514       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2515       .def(
2516           "__str__",
2517           [](PyValue &self) {
2518             PyPrintAccumulator printAccum;
2519             printAccum.parts.append("Value(");
2520             mlirValuePrint(self.get(), printAccum.getCallback(),
2521                            printAccum.getUserData());
2522             printAccum.parts.append(")");
2523             return printAccum.join();
2524           },
2525           kValueDunderStrDocstring)
2526       .def_property_readonly("type", [](PyValue &self) {
2527         return PyType(self.getParentOperation()->getContext(),
2528                       mlirValueGetType(self.get()));
2529       });
  // Concrete Value kinds, registered right after the generic Value binding.
  // Presumably they derive from PyValue, in which case the base must already
  // be registered with pybind11 -- confirm in IRModule.h before reordering.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings.
  // These back the list/iterator objects returned by the properties above
  // (e.g. Region.blocks, Block.operations, Block.arguments).
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  PyGlobalDebugFlag::bind(m);
2547 }
2548