1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include <pybind11/stl.h>
22 
23 namespace py = pybind11;
24 using namespace mlir;
25 using namespace mlir::python;
26 
27 using llvm::SmallVector;
28 using llvm::StringRef;
29 using llvm::Twine;
30 
31 //------------------------------------------------------------------------------
32 // Docstrings (trivial, non-duplicated docstrings are included inline).
33 //------------------------------------------------------------------------------
34 
35 static const char kContextParseTypeDocstring[] =
36     R"(Parses the assembly form of a type.
37 
38 Returns a Type object or raises a ValueError if the type cannot be parsed.
39 
40 See also: https://mlir.llvm.org/docs/LangRef/#type-system
41 )";
42 
43 static const char kContextGetFileLocationDocstring[] =
44     R"(Gets a Location representing a file, line and column)";
45 
46 static const char kModuleParseDocstring[] =
47     R"(Parses a module's assembly format from a string.
48 
49 Returns a new MlirModule or raises a ValueError if the parsing fails.
50 
51 See also: https://mlir.llvm.org/docs/LangRef/
52 )";
53 
54 static const char kOperationCreateDocstring[] =
55     R"(Creates a new operation.
56 
57 Args:
58   name: Operation name (e.g. "dialect.operation").
59   results: Sequence of Type representing op result types.
60   attributes: Dict of str:Attribute.
61   successors: List of Block for the operation's successors.
62   regions: Number of regions to create.
63   location: A Location object (defaults to resolve from context manager).
64   ip: An InsertionPoint (defaults to resolve from context manager or set to
65     False to disable insertion, even with an insertion point set in the
66     context manager).
67 Returns:
68   A new "detached" Operation object. Detached operations can be added
69   to blocks, which causes them to become "attached."
70 )";
71 
72 static const char kOperationPrintDocstring[] =
73     R"(Prints the assembly form of the operation to a file like object.
74 
75 Args:
76   file: The file like object to write to. Defaults to sys.stdout.
77   binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: If set, elide elements attributes that have more than
    this number of elements. Defaults to None (no limit).
80   enable_debug_info: Whether to print debug/location information. Defaults
81     to False.
82   pretty_debug_info: Whether to format debug information for easier reading
83     by a human (warning: the result is unparseable).
84   print_generic_op_form: Whether to print the generic assembly forms of all
85     ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
87     multi-threaded access but may not be consistent with how the overall
88     module prints.
89 )";
90 
91 static const char kOperationGetAsmDocstring[] =
92     R"(Gets the assembly form of the operation with all options available.
93 
94 Args:
95   binary: Whether to return a bytes (True) or str (False) object. Defaults to
96     False.
97   ... others ...: See the print() method for common keyword arguments for
98     configuring the printout.
99 Returns:
100   Either a bytes or str object, depending on the setting of the 'binary'
101   argument.
102 )";
103 
104 static const char kOperationStrDunderDocstring[] =
105     R"(Gets the assembly form of the operation with default options.
106 
107 If more advanced control over the assembly formatting or I/O options is needed,
108 use the dedicated print or get_asm method, which supports keyword arguments to
109 customize behavior.
110 )";
111 
112 static const char kDumpDocstring[] =
113     R"(Dumps a debug representation of the object to stderr.)";
114 
115 static const char kAppendBlockDocstring[] =
116     R"(Appends a new block, with argument types as positional args.
117 
118 Returns:
119   The created block.
120 )";
121 
122 static const char kValueDunderStrDocstring[] =
123     R"(Returns the string form of the value.
124 
125 If the value is a block argument, this is the assembly form of its type and the
126 position in the argument list. If the value is an operation result, this is
127 equivalent to printing the operation that produced it.
128 )";
129 
130 //------------------------------------------------------------------------------
131 // Utilities.
132 //------------------------------------------------------------------------------
133 
134 /// Helper for creating an @classmethod.
135 template <class Func, typename... Args>
136 py::object classmethod(Func f, Args... args) {
137   py::object cf = py::cpp_function(f, args...);
138   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
139 }
140 
141 static py::object
142 createCustomDialectWrapper(const std::string &dialectNamespace,
143                            py::object dialectDescriptor) {
144   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
145   if (!dialectClass) {
146     // Use the base class.
147     return py::cast(PyDialect(std::move(dialectDescriptor)));
148   }
149 
150   // Create the custom implementation.
151   return (*dialectClass)(std::move(dialectDescriptor));
152 }
153 
154 static MlirStringRef toMlirStringRef(const std::string &s) {
155   return mlirStringRefCreate(s.data(), s.size());
156 }
157 
158 /// Wrapper for the global LLVM debugging flag.
159 struct PyGlobalDebugFlag {
160   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
161 
162   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
163 
164   static void bind(py::module &m) {
165     // Debug flags.
166     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
167         .def_property_static("flag", &PyGlobalDebugFlag::get,
168                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
169   }
170 };
171 
172 //------------------------------------------------------------------------------
173 // Collections.
174 //------------------------------------------------------------------------------
175 
176 namespace {
177 
178 class PyRegionIterator {
179 public:
180   PyRegionIterator(PyOperationRef operation)
181       : operation(std::move(operation)) {}
182 
183   PyRegionIterator &dunderIter() { return *this; }
184 
185   PyRegion dunderNext() {
186     operation->checkValid();
187     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
188       throw py::stop_iteration();
189     }
190     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
191     return PyRegion(operation, region);
192   }
193 
194   static void bind(py::module &m) {
195     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
196         .def("__iter__", &PyRegionIterator::dunderIter)
197         .def("__next__", &PyRegionIterator::dunderNext);
198   }
199 
200 private:
201   PyOperationRef operation;
202   int nextIndex = 0;
203 };
204 
205 /// Regions of an op are fixed length and indexed numerically so are represented
206 /// with a sequence-like container.
207 class PyRegionList {
208 public:
209   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
210 
211   intptr_t dunderLen() {
212     operation->checkValid();
213     return mlirOperationGetNumRegions(operation->get());
214   }
215 
216   PyRegion dunderGetItem(intptr_t index) {
217     // dunderLen checks validity.
218     if (index < 0 || index >= dunderLen()) {
219       throw SetPyError(PyExc_IndexError,
220                        "attempt to access out of bounds region");
221     }
222     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
223     return PyRegion(operation, region);
224   }
225 
226   static void bind(py::module &m) {
227     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
228         .def("__len__", &PyRegionList::dunderLen)
229         .def("__getitem__", &PyRegionList::dunderGetItem);
230   }
231 
232 private:
233   PyOperationRef operation;
234 };
235 
236 class PyBlockIterator {
237 public:
238   PyBlockIterator(PyOperationRef operation, MlirBlock next)
239       : operation(std::move(operation)), next(next) {}
240 
241   PyBlockIterator &dunderIter() { return *this; }
242 
243   PyBlock dunderNext() {
244     operation->checkValid();
245     if (mlirBlockIsNull(next)) {
246       throw py::stop_iteration();
247     }
248 
249     PyBlock returnBlock(operation, next);
250     next = mlirBlockGetNextInRegion(next);
251     return returnBlock;
252   }
253 
254   static void bind(py::module &m) {
255     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
256         .def("__iter__", &PyBlockIterator::dunderIter)
257         .def("__next__", &PyBlockIterator::dunderNext);
258   }
259 
260 private:
261   PyOperationRef operation;
262   MlirBlock next;
263 };
264 
265 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
266 /// we present them as a more full-featured list-like container but optimize
267 /// it for forward iteration. Blocks are always owned by a region.
268 class PyBlockList {
269 public:
270   PyBlockList(PyOperationRef operation, MlirRegion region)
271       : operation(std::move(operation)), region(region) {}
272 
273   PyBlockIterator dunderIter() {
274     operation->checkValid();
275     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
276   }
277 
278   intptr_t dunderLen() {
279     operation->checkValid();
280     intptr_t count = 0;
281     MlirBlock block = mlirRegionGetFirstBlock(region);
282     while (!mlirBlockIsNull(block)) {
283       count += 1;
284       block = mlirBlockGetNextInRegion(block);
285     }
286     return count;
287   }
288 
289   PyBlock dunderGetItem(intptr_t index) {
290     operation->checkValid();
291     if (index < 0) {
292       throw SetPyError(PyExc_IndexError,
293                        "attempt to access out of bounds block");
294     }
295     MlirBlock block = mlirRegionGetFirstBlock(region);
296     while (!mlirBlockIsNull(block)) {
297       if (index == 0) {
298         return PyBlock(operation, block);
299       }
300       block = mlirBlockGetNextInRegion(block);
301       index -= 1;
302     }
303     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
304   }
305 
306   PyBlock appendBlock(py::args pyArgTypes) {
307     operation->checkValid();
308     llvm::SmallVector<MlirType, 4> argTypes;
309     argTypes.reserve(pyArgTypes.size());
310     for (auto &pyArg : pyArgTypes) {
311       argTypes.push_back(pyArg.cast<PyType &>());
312     }
313 
314     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
315     mlirRegionAppendOwnedBlock(region, block);
316     return PyBlock(operation, block);
317   }
318 
319   static void bind(py::module &m) {
320     py::class_<PyBlockList>(m, "BlockList", py::module_local())
321         .def("__getitem__", &PyBlockList::dunderGetItem)
322         .def("__iter__", &PyBlockList::dunderIter)
323         .def("__len__", &PyBlockList::dunderLen)
324         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
325   }
326 
327 private:
328   PyOperationRef operation;
329   MlirRegion region;
330 };
331 
332 class PyOperationIterator {
333 public:
334   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
335       : parentOperation(std::move(parentOperation)), next(next) {}
336 
337   PyOperationIterator &dunderIter() { return *this; }
338 
339   py::object dunderNext() {
340     parentOperation->checkValid();
341     if (mlirOperationIsNull(next)) {
342       throw py::stop_iteration();
343     }
344 
345     PyOperationRef returnOperation =
346         PyOperation::forOperation(parentOperation->getContext(), next);
347     next = mlirOperationGetNextInBlock(next);
348     return returnOperation->createOpView();
349   }
350 
351   static void bind(py::module &m) {
352     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
353         .def("__iter__", &PyOperationIterator::dunderIter)
354         .def("__next__", &PyOperationIterator::dunderNext);
355   }
356 
357 private:
358   PyOperationRef parentOperation;
359   MlirOperation next;
360 };
361 
362 /// Operations are exposed by the C-API as a forward-only linked list. In
363 /// Python, we present them as a more full-featured list-like container but
364 /// optimize it for forward iteration. Iterable operations are always owned
365 /// by a block.
366 class PyOperationList {
367 public:
368   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
369       : parentOperation(std::move(parentOperation)), block(block) {}
370 
371   PyOperationIterator dunderIter() {
372     parentOperation->checkValid();
373     return PyOperationIterator(parentOperation,
374                                mlirBlockGetFirstOperation(block));
375   }
376 
377   intptr_t dunderLen() {
378     parentOperation->checkValid();
379     intptr_t count = 0;
380     MlirOperation childOp = mlirBlockGetFirstOperation(block);
381     while (!mlirOperationIsNull(childOp)) {
382       count += 1;
383       childOp = mlirOperationGetNextInBlock(childOp);
384     }
385     return count;
386   }
387 
388   py::object dunderGetItem(intptr_t index) {
389     parentOperation->checkValid();
390     if (index < 0) {
391       throw SetPyError(PyExc_IndexError,
392                        "attempt to access out of bounds operation");
393     }
394     MlirOperation childOp = mlirBlockGetFirstOperation(block);
395     while (!mlirOperationIsNull(childOp)) {
396       if (index == 0) {
397         return PyOperation::forOperation(parentOperation->getContext(), childOp)
398             ->createOpView();
399       }
400       childOp = mlirOperationGetNextInBlock(childOp);
401       index -= 1;
402     }
403     throw SetPyError(PyExc_IndexError,
404                      "attempt to access out of bounds operation");
405   }
406 
407   static void bind(py::module &m) {
408     py::class_<PyOperationList>(m, "OperationList", py::module_local())
409         .def("__getitem__", &PyOperationList::dunderGetItem)
410         .def("__iter__", &PyOperationList::dunderIter)
411         .def("__len__", &PyOperationList::dunderLen);
412   }
413 
414 private:
415   PyOperationRef parentOperation;
416   MlirBlock block;
417 };
418 
419 } // namespace
420 
421 //------------------------------------------------------------------------------
422 // PyMlirContext
423 //------------------------------------------------------------------------------
424 
425 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
426   py::gil_scoped_acquire acquire;
427   auto &liveContexts = getLiveContexts();
428   liveContexts[context.ptr] = this;
429 }
430 
431 PyMlirContext::~PyMlirContext() {
432   // Note that the only public way to construct an instance is via the
433   // forContext method, which always puts the associated handle into
434   // liveContexts.
435   py::gil_scoped_acquire acquire;
436   getLiveContexts().erase(context.ptr);
437   mlirContextDestroy(context);
438 }
439 
440 py::object PyMlirContext::getCapsule() {
441   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
442 }
443 
444 py::object PyMlirContext::createFromCapsule(py::object capsule) {
445   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
446   if (mlirContextIsNull(rawContext))
447     throw py::error_already_set();
448   return forContext(rawContext).releaseObject();
449 }
450 
451 PyMlirContext *PyMlirContext::createNewContextForInit() {
452   MlirContext context = mlirContextCreate();
453   mlirRegisterAllDialects(context);
454   return new PyMlirContext(context);
455 }
456 
457 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
458   py::gil_scoped_acquire acquire;
459   auto &liveContexts = getLiveContexts();
460   auto it = liveContexts.find(context.ptr);
461   if (it == liveContexts.end()) {
462     // Create.
463     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
464     py::object pyRef = py::cast(unownedContextWrapper);
465     assert(pyRef && "cast to py::object failed");
466     liveContexts[context.ptr] = unownedContextWrapper;
467     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
468   }
469   // Use existing.
470   py::object pyRef = py::cast(it->second);
471   return PyMlirContextRef(it->second, std::move(pyRef));
472 }
473 
474 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
475   static LiveContextMap liveContexts;
476   return liveContexts;
477 }
478 
479 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
480 
481 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
482 
483 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
484 
485 pybind11::object PyMlirContext::contextEnter() {
486   return PyThreadContextEntry::pushContext(*this);
487 }
488 
489 void PyMlirContext::contextExit(pybind11::object excType,
490                                 pybind11::object excVal,
491                                 pybind11::object excTb) {
492   PyThreadContextEntry::popContext(*this);
493 }
494 
495 PyMlirContext &DefaultingPyMlirContext::resolve() {
496   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
497   if (!context) {
498     throw SetPyError(
499         PyExc_RuntimeError,
500         "An MLIR function requires a Context but none was provided in the call "
501         "or from the surrounding environment. Either pass to the function with "
502         "a 'context=' argument or establish a default using 'with Context():'");
503   }
504   return *context;
505 }
506 
507 //------------------------------------------------------------------------------
508 // PyThreadContextEntry management
509 //------------------------------------------------------------------------------
510 
511 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
512   static thread_local std::vector<PyThreadContextEntry> stack;
513   return stack;
514 }
515 
516 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
517   auto &stack = getStack();
518   if (stack.empty())
519     return nullptr;
520   return &stack.back();
521 }
522 
523 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
524                                 py::object insertionPoint,
525                                 py::object location) {
526   auto &stack = getStack();
527   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
528                      std::move(location));
529   // If the new stack has more than one entry and the context of the new top
530   // entry matches the previous, copy the insertionPoint and location from the
531   // previous entry if missing from the new top entry.
532   if (stack.size() > 1) {
533     auto &prev = *(stack.rbegin() + 1);
534     auto &current = stack.back();
535     if (current.context.is(prev.context)) {
536       // Default non-context objects from the previous entry.
537       if (!current.insertionPoint)
538         current.insertionPoint = prev.insertionPoint;
539       if (!current.location)
540         current.location = prev.location;
541     }
542   }
543 }
544 
545 PyMlirContext *PyThreadContextEntry::getContext() {
546   if (!context)
547     return nullptr;
548   return py::cast<PyMlirContext *>(context);
549 }
550 
551 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
552   if (!insertionPoint)
553     return nullptr;
554   return py::cast<PyInsertionPoint *>(insertionPoint);
555 }
556 
557 PyLocation *PyThreadContextEntry::getLocation() {
558   if (!location)
559     return nullptr;
560   return py::cast<PyLocation *>(location);
561 }
562 
563 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
564   auto *tos = getTopOfStack();
565   return tos ? tos->getContext() : nullptr;
566 }
567 
568 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
569   auto *tos = getTopOfStack();
570   return tos ? tos->getInsertionPoint() : nullptr;
571 }
572 
573 PyLocation *PyThreadContextEntry::getDefaultLocation() {
574   auto *tos = getTopOfStack();
575   return tos ? tos->getLocation() : nullptr;
576 }
577 
578 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
579   py::object contextObj = py::cast(context);
580   push(FrameKind::Context, /*context=*/contextObj,
581        /*insertionPoint=*/py::object(),
582        /*location=*/py::object());
583   return contextObj;
584 }
585 
586 void PyThreadContextEntry::popContext(PyMlirContext &context) {
587   auto &stack = getStack();
588   if (stack.empty())
589     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
590   auto &tos = stack.back();
591   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
592     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
593   stack.pop_back();
594 }
595 
596 py::object
597 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
598   py::object contextObj =
599       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
600   py::object insertionPointObj = py::cast(insertionPoint);
601   push(FrameKind::InsertionPoint,
602        /*context=*/contextObj,
603        /*insertionPoint=*/insertionPointObj,
604        /*location=*/py::object());
605   return insertionPointObj;
606 }
607 
608 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
609   auto &stack = getStack();
610   if (stack.empty())
611     throw SetPyError(PyExc_RuntimeError,
612                      "Unbalanced InsertionPoint enter/exit");
613   auto &tos = stack.back();
614   if (tos.frameKind != FrameKind::InsertionPoint &&
615       tos.getInsertionPoint() != &insertionPoint)
616     throw SetPyError(PyExc_RuntimeError,
617                      "Unbalanced InsertionPoint enter/exit");
618   stack.pop_back();
619 }
620 
621 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
622   py::object contextObj = location.getContext().getObject();
623   py::object locationObj = py::cast(location);
624   push(FrameKind::Location, /*context=*/contextObj,
625        /*insertionPoint=*/py::object(),
626        /*location=*/locationObj);
627   return locationObj;
628 }
629 
630 void PyThreadContextEntry::popLocation(PyLocation &location) {
631   auto &stack = getStack();
632   if (stack.empty())
633     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
634   auto &tos = stack.back();
635   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
636     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
637   stack.pop_back();
638 }
639 
640 //------------------------------------------------------------------------------
641 // PyDialect, PyDialectDescriptor, PyDialects
642 //------------------------------------------------------------------------------
643 
644 MlirDialect PyDialects::getDialectForKey(const std::string &key,
645                                          bool attrError) {
646   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
647                                                     {key.data(), key.size()});
648   if (mlirDialectIsNull(dialect)) {
649     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
650                      Twine("Dialect '") + key + "' not found");
651   }
652   return dialect;
653 }
654 
655 //------------------------------------------------------------------------------
656 // PyLocation
657 //------------------------------------------------------------------------------
658 
659 py::object PyLocation::getCapsule() {
660   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
661 }
662 
663 PyLocation PyLocation::createFromCapsule(py::object capsule) {
664   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
665   if (mlirLocationIsNull(rawLoc))
666     throw py::error_already_set();
667   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
668                     rawLoc);
669 }
670 
671 py::object PyLocation::contextEnter() {
672   return PyThreadContextEntry::pushLocation(*this);
673 }
674 
675 void PyLocation::contextExit(py::object excType, py::object excVal,
676                              py::object excTb) {
677   PyThreadContextEntry::popLocation(*this);
678 }
679 
680 PyLocation &DefaultingPyLocation::resolve() {
681   auto *location = PyThreadContextEntry::getDefaultLocation();
682   if (!location) {
683     throw SetPyError(
684         PyExc_RuntimeError,
685         "An MLIR function requires a Location but none was provided in the "
686         "call or from the surrounding environment. Either pass to the function "
687         "with a 'loc=' argument or establish a default using 'with loc:'");
688   }
689   return *location;
690 }
691 
692 //------------------------------------------------------------------------------
693 // PyModule
694 //------------------------------------------------------------------------------
695 
696 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
697     : BaseContextObject(std::move(contextRef)), module(module) {}
698 
699 PyModule::~PyModule() {
700   py::gil_scoped_acquire acquire;
701   auto &liveModules = getContext()->liveModules;
702   assert(liveModules.count(module.ptr) == 1 &&
703          "destroying module not in live map");
704   liveModules.erase(module.ptr);
705   mlirModuleDestroy(module);
706 }
707 
708 PyModuleRef PyModule::forModule(MlirModule module) {
709   MlirContext context = mlirModuleGetContext(module);
710   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
711 
712   py::gil_scoped_acquire acquire;
713   auto &liveModules = contextRef->liveModules;
714   auto it = liveModules.find(module.ptr);
715   if (it == liveModules.end()) {
716     // Create.
717     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
718     // Note that the default return value policy on cast is automatic_reference,
719     // which does not take ownership (delete will not be called).
720     // Just be explicit.
721     py::object pyRef =
722         py::cast(unownedModule, py::return_value_policy::take_ownership);
723     unownedModule->handle = pyRef;
724     liveModules[module.ptr] =
725         std::make_pair(unownedModule->handle, unownedModule);
726     return PyModuleRef(unownedModule, std::move(pyRef));
727   }
728   // Use existing.
729   PyModule *existing = it->second.second;
730   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
731   return PyModuleRef(existing, std::move(pyRef));
732 }
733 
734 py::object PyModule::createFromCapsule(py::object capsule) {
735   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
736   if (mlirModuleIsNull(rawModule))
737     throw py::error_already_set();
738   return forModule(rawModule).releaseObject();
739 }
740 
741 py::object PyModule::getCapsule() {
742   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
743 }
744 
745 //------------------------------------------------------------------------------
746 // PyOperation
747 //------------------------------------------------------------------------------
748 
749 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
750     : BaseContextObject(std::move(contextRef)), operation(operation) {}
751 
752 PyOperation::~PyOperation() {
753   // If the operation has already been invalidated there is nothing to do.
754   if (!valid)
755     return;
756   auto &liveOperations = getContext()->liveOperations;
757   assert(liveOperations.count(operation.ptr) == 1 &&
758          "destroying operation not in live map");
759   liveOperations.erase(operation.ptr);
760   if (!isAttached()) {
761     mlirOperationDestroy(operation);
762   }
763 }
764 
765 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
766                                            MlirOperation operation,
767                                            py::object parentKeepAlive) {
768   auto &liveOperations = contextRef->liveOperations;
769   // Create.
770   PyOperation *unownedOperation =
771       new PyOperation(std::move(contextRef), operation);
772   // Note that the default return value policy on cast is automatic_reference,
773   // which does not take ownership (delete will not be called).
774   // Just be explicit.
775   py::object pyRef =
776       py::cast(unownedOperation, py::return_value_policy::take_ownership);
777   unownedOperation->handle = pyRef;
778   if (parentKeepAlive) {
779     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
780   }
781   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
782   return PyOperationRef(unownedOperation, std::move(pyRef));
783 }
784 
785 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
786                                          MlirOperation operation,
787                                          py::object parentKeepAlive) {
788   auto &liveOperations = contextRef->liveOperations;
789   auto it = liveOperations.find(operation.ptr);
790   if (it == liveOperations.end()) {
791     // Create.
792     return createInstance(std::move(contextRef), operation,
793                           std::move(parentKeepAlive));
794   }
795   // Use existing.
796   PyOperation *existing = it->second.second;
797   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
798   return PyOperationRef(existing, std::move(pyRef));
799 }
800 
801 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
802                                            MlirOperation operation,
803                                            py::object parentKeepAlive) {
804   auto &liveOperations = contextRef->liveOperations;
805   assert(liveOperations.count(operation.ptr) == 0 &&
806          "cannot create detached operation that already exists");
807   (void)liveOperations;
808 
809   PyOperationRef created = createInstance(std::move(contextRef), operation,
810                                           std::move(parentKeepAlive));
811   created->attached = false;
812   return created;
813 }
814 
815 void PyOperation::checkValid() const {
816   if (!valid) {
817     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
818   }
819 }
820 
821 void PyOperationBase::print(py::object fileObject, bool binary,
822                             llvm::Optional<int64_t> largeElementsLimit,
823                             bool enableDebugInfo, bool prettyDebugInfo,
824                             bool printGenericOpForm, bool useLocalScope) {
825   PyOperation &operation = getOperation();
826   operation.checkValid();
827   if (fileObject.is_none())
828     fileObject = py::module::import("sys").attr("stdout");
829 
830   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
831     fileObject.attr("write")("// Verification failed, printing generic form\n");
832     printGenericOpForm = true;
833   }
834 
835   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
836   if (largeElementsLimit)
837     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
838   if (enableDebugInfo)
839     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
840   if (printGenericOpForm)
841     mlirOpPrintingFlagsPrintGenericOpForm(flags);
842 
843   PyFileAccumulator accum(fileObject, binary);
844   py::gil_scoped_release();
845   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
846                               accum.getUserData());
847   mlirOpPrintingFlagsDestroy(flags);
848 }
849 
850 py::object PyOperationBase::getAsm(bool binary,
851                                    llvm::Optional<int64_t> largeElementsLimit,
852                                    bool enableDebugInfo, bool prettyDebugInfo,
853                                    bool printGenericOpForm,
854                                    bool useLocalScope) {
855   py::object fileObject;
856   if (binary) {
857     fileObject = py::module::import("io").attr("BytesIO")();
858   } else {
859     fileObject = py::module::import("io").attr("StringIO")();
860   }
861   print(fileObject, /*binary=*/binary,
862         /*largeElementsLimit=*/largeElementsLimit,
863         /*enableDebugInfo=*/enableDebugInfo,
864         /*prettyDebugInfo=*/prettyDebugInfo,
865         /*printGenericOpForm=*/printGenericOpForm,
866         /*useLocalScope=*/useLocalScope);
867 
868   return fileObject.attr("getvalue")();
869 }
870 
871 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
872   checkValid();
873   if (!isAttached())
874     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
875   MlirOperation operation = mlirOperationGetParentOperation(get());
876   if (mlirOperationIsNull(operation))
877     return {};
878   return PyOperation::forOperation(getContext(), operation);
879 }
880 
881 PyBlock PyOperation::getBlock() {
882   checkValid();
883   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
884   MlirBlock block = mlirOperationGetBlock(get());
885   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
886   assert(parentOperation && "Operation has no parent");
887   return PyBlock{std::move(*parentOperation), block};
888 }
889 
890 py::object PyOperation::getCapsule() {
891   checkValid();
892   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
893 }
894 
895 py::object PyOperation::createFromCapsule(py::object capsule) {
896   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
897   if (mlirOperationIsNull(rawOperation))
898     throw py::error_already_set();
899   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
900   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
901       .releaseObject();
902 }
903 
/// Creates a new operation named `name` from unpacked Python arguments and
/// returns its OpView wrapper (see createOpView).
///
/// Args:
///   results: result types; None entries raise ValueError.
///   operands: operand values; None entries raise ValueError.
///   attributes: dict of str -> Attribute; bad keys/values raise cast_error.
///   successors: successor blocks; None entries raise ValueError.
///   regions: number of empty regions to create; must be >= 0.
///   location: location of the new op; its context owns the new wrapper.
///   maybeIp: `False` disables insertion; None uses the default insertion
///     point from PyThreadContextEntry (if any); otherwise it must be an
///     InsertionPoint.
///
/// All validation happens before the MlirOperationState is built, because
/// the state is not cleaned up if an exception escapes after that point.
py::object PyOperation::create(
    std::string name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, py::object maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Attribute names are stored as owned std::strings so that the
  // MlirStringRefs built from them below stay valid until the op is created.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Regions are created empty; "Owned" indicates the state takes them over.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);

  // InsertPoint active?
  // `False` (compared by identity with Python's False) opts out of insertion;
  // None falls back to the default insertion point, which may be null.
  if (!maybeIp.is(py::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
      ip = PyThreadContextEntry::getDefaultInsertionPoint();
    } else {
      ip = py::cast<PyInsertionPoint *>(maybeIp);
    }
    if (ip)
      ip->insert(*created.get());
  }

  return created->createOpView();
}
1036 
1037 py::object PyOperation::createOpView() {
1038   checkValid();
1039   MlirIdentifier ident = mlirOperationGetName(get());
1040   MlirStringRef identStr = mlirIdentifierStr(ident);
1041   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1042       StringRef(identStr.data, identStr.length));
1043   if (opViewClass)
1044     return (*opViewClass)(getRef().getObject());
1045   return py::cast(PyOpView(getRef().getObject()));
1046 }
1047 
1048 void PyOperation::erase() {
1049   checkValid();
1050   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1051   // Python reference to a child operation is live. All children should also
1052   // have their `valid` bit set to false.
1053   auto &liveOperations = getContext()->liveOperations;
1054   if (liveOperations.count(operation.ptr))
1055     liveOperations.erase(operation.ptr);
1056   mlirOperationDestroy(operation);
1057   valid = false;
1058 }
1059 
1060 //------------------------------------------------------------------------------
1061 // PyOpView
1062 //------------------------------------------------------------------------------
1063 
1064 py::object
1065 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1066                        py::list operandList,
1067                        llvm::Optional<py::dict> attributes,
1068                        llvm::Optional<std::vector<PyBlock *>> successors,
1069                        llvm::Optional<int> regions,
1070                        DefaultingPyLocation location, py::object maybeIp) {
1071   PyMlirContextRef context = location->getContext();
1072   // Class level operation construction metadata.
1073   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1074   // Operand and result segment specs are either none, which does no
1075   // variadic unpacking, or a list of ints with segment sizes, where each
1076   // element is either a positive number (typically 1 for a scalar) or -1 to
1077   // indicate that it is derived from the length of the same-indexed operand
1078   // or result (implying that it is a list at that position).
1079   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1080   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1081 
1082   std::vector<uint32_t> operandSegmentLengths;
1083   std::vector<uint32_t> resultSegmentLengths;
1084 
1085   // Validate/determine region count.
1086   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1087   int opMinRegionCount = std::get<0>(opRegionSpec);
1088   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1089   if (!regions) {
1090     regions = opMinRegionCount;
1091   }
1092   if (*regions < opMinRegionCount) {
1093     throw py::value_error(
1094         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1095          llvm::Twine(opMinRegionCount) +
1096          " regions but was built with regions=" + llvm::Twine(*regions))
1097             .str());
1098   }
1099   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1100     throw py::value_error(
1101         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1102          llvm::Twine(opMinRegionCount) +
1103          " regions but was built with regions=" + llvm::Twine(*regions))
1104             .str());
1105   }
1106 
1107   // Unpack results.
1108   std::vector<PyType *> resultTypes;
1109   resultTypes.reserve(resultTypeList.size());
1110   if (resultSegmentSpecObj.is_none()) {
1111     // Non-variadic result unpacking.
1112     for (auto it : llvm::enumerate(resultTypeList)) {
1113       try {
1114         resultTypes.push_back(py::cast<PyType *>(it.value()));
1115         if (!resultTypes.back())
1116           throw py::cast_error();
1117       } catch (py::cast_error &err) {
1118         throw py::value_error((llvm::Twine("Result ") +
1119                                llvm::Twine(it.index()) + " of operation \"" +
1120                                name + "\" must be a Type (" + err.what() + ")")
1121                                   .str());
1122       }
1123     }
1124   } else {
1125     // Sized result unpacking.
1126     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1127     if (resultSegmentSpec.size() != resultTypeList.size()) {
1128       throw py::value_error((llvm::Twine("Operation \"") + name +
1129                              "\" requires " +
1130                              llvm::Twine(resultSegmentSpec.size()) +
1131                              "result segments but was provided " +
1132                              llvm::Twine(resultTypeList.size()))
1133                                 .str());
1134     }
1135     resultSegmentLengths.reserve(resultTypeList.size());
1136     for (auto it :
1137          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1138       int segmentSpec = std::get<1>(it.value());
1139       if (segmentSpec == 1 || segmentSpec == 0) {
1140         // Unpack unary element.
1141         try {
1142           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1143           if (resultType) {
1144             resultTypes.push_back(resultType);
1145             resultSegmentLengths.push_back(1);
1146           } else if (segmentSpec == 0) {
1147             // Allowed to be optional.
1148             resultSegmentLengths.push_back(0);
1149           } else {
1150             throw py::cast_error("was None and result is not optional");
1151           }
1152         } catch (py::cast_error &err) {
1153           throw py::value_error((llvm::Twine("Result ") +
1154                                  llvm::Twine(it.index()) + " of operation \"" +
1155                                  name + "\" must be a Type (" + err.what() +
1156                                  ")")
1157                                     .str());
1158         }
1159       } else if (segmentSpec == -1) {
1160         // Unpack sequence by appending.
1161         try {
1162           if (std::get<0>(it.value()).is_none()) {
1163             // Treat it as an empty list.
1164             resultSegmentLengths.push_back(0);
1165           } else {
1166             // Unpack the list.
1167             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1168             for (py::object segmentItem : segment) {
1169               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1170               if (!resultTypes.back()) {
1171                 throw py::cast_error("contained a None item");
1172               }
1173             }
1174             resultSegmentLengths.push_back(segment.size());
1175           }
1176         } catch (std::exception &err) {
1177           // NOTE: Sloppy to be using a catch-all here, but there are at least
1178           // three different unrelated exceptions that can be thrown in the
1179           // above "casts". Just keep the scope above small and catch them all.
1180           throw py::value_error((llvm::Twine("Result ") +
1181                                  llvm::Twine(it.index()) + " of operation \"" +
1182                                  name + "\" must be a Sequence of Types (" +
1183                                  err.what() + ")")
1184                                     .str());
1185         }
1186       } else {
1187         throw py::value_error("Unexpected segment spec");
1188       }
1189     }
1190   }
1191 
1192   // Unpack operands.
1193   std::vector<PyValue *> operands;
1194   operands.reserve(operands.size());
1195   if (operandSegmentSpecObj.is_none()) {
1196     // Non-sized operand unpacking.
1197     for (auto it : llvm::enumerate(operandList)) {
1198       try {
1199         operands.push_back(py::cast<PyValue *>(it.value()));
1200         if (!operands.back())
1201           throw py::cast_error();
1202       } catch (py::cast_error &err) {
1203         throw py::value_error((llvm::Twine("Operand ") +
1204                                llvm::Twine(it.index()) + " of operation \"" +
1205                                name + "\" must be a Value (" + err.what() + ")")
1206                                   .str());
1207       }
1208     }
1209   } else {
1210     // Sized operand unpacking.
1211     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1212     if (operandSegmentSpec.size() != operandList.size()) {
1213       throw py::value_error((llvm::Twine("Operation \"") + name +
1214                              "\" requires " +
1215                              llvm::Twine(operandSegmentSpec.size()) +
1216                              "operand segments but was provided " +
1217                              llvm::Twine(operandList.size()))
1218                                 .str());
1219     }
1220     operandSegmentLengths.reserve(operandList.size());
1221     for (auto it :
1222          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1223       int segmentSpec = std::get<1>(it.value());
1224       if (segmentSpec == 1 || segmentSpec == 0) {
1225         // Unpack unary element.
1226         try {
1227           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1228           if (operandValue) {
1229             operands.push_back(operandValue);
1230             operandSegmentLengths.push_back(1);
1231           } else if (segmentSpec == 0) {
1232             // Allowed to be optional.
1233             operandSegmentLengths.push_back(0);
1234           } else {
1235             throw py::cast_error("was None and operand is not optional");
1236           }
1237         } catch (py::cast_error &err) {
1238           throw py::value_error((llvm::Twine("Operand ") +
1239                                  llvm::Twine(it.index()) + " of operation \"" +
1240                                  name + "\" must be a Value (" + err.what() +
1241                                  ")")
1242                                     .str());
1243         }
1244       } else if (segmentSpec == -1) {
1245         // Unpack sequence by appending.
1246         try {
1247           if (std::get<0>(it.value()).is_none()) {
1248             // Treat it as an empty list.
1249             operandSegmentLengths.push_back(0);
1250           } else {
1251             // Unpack the list.
1252             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1253             for (py::object segmentItem : segment) {
1254               operands.push_back(py::cast<PyValue *>(segmentItem));
1255               if (!operands.back()) {
1256                 throw py::cast_error("contained a None item");
1257               }
1258             }
1259             operandSegmentLengths.push_back(segment.size());
1260           }
1261         } catch (std::exception &err) {
1262           // NOTE: Sloppy to be using a catch-all here, but there are at least
1263           // three different unrelated exceptions that can be thrown in the
1264           // above "casts". Just keep the scope above small and catch them all.
1265           throw py::value_error((llvm::Twine("Operand ") +
1266                                  llvm::Twine(it.index()) + " of operation \"" +
1267                                  name + "\" must be a Sequence of Values (" +
1268                                  err.what() + ")")
1269                                     .str());
1270         }
1271       } else {
1272         throw py::value_error("Unexpected segment spec");
1273       }
1274     }
1275   }
1276 
1277   // Merge operand/result segment lengths into attributes if needed.
1278   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1279     // Dup.
1280     if (attributes) {
1281       attributes = py::dict(*attributes);
1282     } else {
1283       attributes = py::dict();
1284     }
1285     if (attributes->contains("result_segment_sizes") ||
1286         attributes->contains("operand_segment_sizes")) {
1287       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1288                             "'operand_segment_sizes' attribute is unsupported. "
1289                             "Use Operation.create for such low-level access.");
1290     }
1291 
1292     // Add result_segment_sizes attribute.
1293     if (!resultSegmentLengths.empty()) {
1294       int64_t size = resultSegmentLengths.size();
1295       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1296           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1297           resultSegmentLengths.size(), resultSegmentLengths.data());
1298       (*attributes)["result_segment_sizes"] =
1299           PyAttribute(context, segmentLengthAttr);
1300     }
1301 
1302     // Add operand_segment_sizes attribute.
1303     if (!operandSegmentLengths.empty()) {
1304       int64_t size = operandSegmentLengths.size();
1305       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1306           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1307           operandSegmentLengths.size(), operandSegmentLengths.data());
1308       (*attributes)["operand_segment_sizes"] =
1309           PyAttribute(context, segmentLengthAttr);
1310     }
1311   }
1312 
1313   // Delegate to create.
1314   return PyOperation::create(std::move(name),
1315                              /*results=*/std::move(resultTypes),
1316                              /*operands=*/std::move(operands),
1317                              /*attributes=*/std::move(attributes),
1318                              /*successors=*/std::move(successors),
1319                              /*regions=*/*regions, location, maybeIp);
1320 }
1321 
/// Wraps an existing operation object in an OpView.
///
/// `operationObject` may be any PyOperationBase subclass (an Operation or
/// another OpView). The stored `operationObject` member is re-derived from
/// the underlying PyOperation's own Python object rather than taken from the
/// argument.
// NOTE(review): the second initializer reads the `operation` member, so this
// relies on `operation` being declared before `operationObject` in the class
// (members are initialized in declaration order) - confirm in IRModule.h.
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1327 
1328 py::object PyOpView::createRawSubclass(py::object userClass) {
1329   // This is... a little gross. The typical pattern is to have a pure python
1330   // class that extends OpView like:
1331   //   class AddFOp(_cext.ir.OpView):
1332   //     def __init__(self, loc, lhs, rhs):
1333   //       operation = loc.context.create_operation(
1334   //           "addf", lhs, rhs, results=[lhs.type])
1335   //       super().__init__(operation)
1336   //
1337   // I.e. The goal of the user facing type is to provide a nice constructor
1338   // that has complete freedom for the op under construction. This is at odds
1339   // with our other desire to sometimes create this object by just passing an
1340   // operation (to initialize the base class). We could do *arg and **kwargs
1341   // munging to try to make it work, but instead, we synthesize a new class
1342   // on the fly which extends this user class (AddFOp in this example) and
1343   // *give it* the base class's __init__ method, thus bypassing the
1344   // intermediate subclass's __init__ method entirely. While slightly,
1345   // underhanded, this is safe/legal because the type hierarchy has not changed
1346   // (we just added a new leaf) and we aren't mucking around with __new__.
1347   // Typically, this new class will be stored on the original as "_Raw" and will
1348   // be used for casts and other things that need a variant of the class that
1349   // is initialized purely from an operation.
1350   py::object parentMetaclass =
1351       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1352   py::dict attributes;
1353   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1354   // now.
1355   //   auto opViewType = py::type::of<PyOpView>();
1356   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1357   attributes["__init__"] = opViewType.attr("__init__");
1358   py::str origName = userClass.attr("__name__");
1359   py::str newName = py::str("_") + origName;
1360   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1361 }
1362 
1363 //------------------------------------------------------------------------------
1364 // PyInsertionPoint.
1365 //------------------------------------------------------------------------------
1366 
1367 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1368 
/// Creates an insertion point that inserts immediately before
/// `beforeOperationBase`, inside that operation's containing block.
/// Raises (via PyOperation::getBlock) if the operation is detached.
// NOTE(review): `block` is initialized from `refOperation`, so this relies on
// `refOperation` being declared before `block` in the class (members are
// initialized in declaration order) - confirm in IRModule.h.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1372 
1373 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1374   PyOperation &operation = operationBase.getOperation();
1375   if (operation.isAttached())
1376     throw SetPyError(PyExc_ValueError,
1377                      "Attempt to insert operation that is already attached");
1378   block.getParentOperation()->checkValid();
1379   MlirOperation beforeOp = {nullptr};
1380   if (refOperation) {
1381     // Insert before operation.
1382     (*refOperation)->checkValid();
1383     beforeOp = (*refOperation)->get();
1384   } else {
1385     // Insert at end (before null) is only valid if the block does not
1386     // already end in a known terminator (violating this will cause assertion
1387     // failures later).
1388     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1389       throw py::index_error("Cannot insert operation at the end of a block "
1390                             "that already has a terminator. Did you mean to "
1391                             "use 'InsertionPoint.at_block_terminator(block)' "
1392                             "versus 'InsertionPoint(block)'?");
1393     }
1394   }
1395   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1396   operation.setAttached();
1397 }
1398 
1399 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1400   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1401   if (mlirOperationIsNull(firstOp)) {
1402     // Just insert at end.
1403     return PyInsertionPoint(block);
1404   }
1405 
1406   // Insert before first op.
1407   PyOperationRef firstOpRef = PyOperation::forOperation(
1408       block.getParentOperation()->getContext(), firstOp);
1409   return PyInsertionPoint{block, std::move(firstOpRef)};
1410 }
1411 
1412 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1413   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1414   if (mlirOperationIsNull(terminator))
1415     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1416   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1417       block.getParentOperation()->getContext(), terminator);
1418   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1419 }
1420 
1421 py::object PyInsertionPoint::contextEnter() {
1422   return PyThreadContextEntry::pushInsertionPoint(*this);
1423 }
1424 
1425 void PyInsertionPoint::contextExit(pybind11::object excType,
1426                                    pybind11::object excVal,
1427                                    pybind11::object excTb) {
1428   PyThreadContextEntry::popInsertionPoint(*this);
1429 }
1430 
1431 //------------------------------------------------------------------------------
1432 // PyAttribute.
1433 //------------------------------------------------------------------------------
1434 
1435 bool PyAttribute::operator==(const PyAttribute &other) {
1436   return mlirAttributeEqual(attr, other.attr);
1437 }
1438 
1439 py::object PyAttribute::getCapsule() {
1440   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1441 }
1442 
1443 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1444   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1445   if (mlirAttributeIsNull(rawAttr))
1446     throw py::error_already_set();
1447   return PyAttribute(
1448       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1449 }
1450 
1451 //------------------------------------------------------------------------------
1452 // PyNamedAttribute.
1453 //------------------------------------------------------------------------------
1454 
1455 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1456     : ownedName(new std::string(std::move(ownedName))) {
1457   namedAttr = mlirNamedAttributeGet(
1458       mlirIdentifierGet(mlirAttributeGetContext(attr),
1459                         toMlirStringRef(*this->ownedName)),
1460       attr);
1461 }
1462 
1463 //------------------------------------------------------------------------------
1464 // PyType.
1465 //------------------------------------------------------------------------------
1466 
1467 bool PyType::operator==(const PyType &other) {
1468   return mlirTypeEqual(type, other.type);
1469 }
1470 
1471 py::object PyType::getCapsule() {
1472   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1473 }
1474 
1475 PyType PyType::createFromCapsule(py::object capsule) {
1476   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1477   if (mlirTypeIsNull(rawType))
1478     throw py::error_already_set();
1479   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1480                 rawType);
1481 }
1482 
1483 //------------------------------------------------------------------------------
// PyValue and subclasses.
1485 //------------------------------------------------------------------------------
1486 
1487 pybind11::object PyValue::getCapsule() {
1488   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1489 }
1490 
1491 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1492   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1493   if (mlirValueIsNull(value))
1494     throw py::error_already_set();
1495   MlirOperation owner;
1496   if (mlirValueIsAOpResult(value))
1497     owner = mlirOpResultGetOwner(value);
1498   if (mlirValueIsABlockArgument(value))
1499     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1500   if (mlirOperationIsNull(owner))
1501     throw py::error_already_set();
1502   MlirContext ctx = mlirOperationGetContext(owner);
1503   PyOperationRef ownerRef =
1504       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1505   return PyValue(ownerRef, value);
1506 }
1507 
1508 namespace {
1509 /// CRTP base class for Python MLIR values that subclass Value and should be
1510 /// castable from it. The value hierarchy is one level deep and is not supposed
1511 /// to accommodate other levels unless core MLIR changes.
1512 template <typename DerivedTy>
1513 class PyConcreteValue : public PyValue {
1514 public:
1515   // Derived classes must define statics for:
1516   //   IsAFunctionTy isaFunction
1517   //   const char *pyClassName
1518   // and redefine bindDerived.
1519   using ClassTy = py::class_<DerivedTy, PyValue>;
1520   using IsAFunctionTy = bool (*)(MlirValue);
1521 
1522   PyConcreteValue() = default;
1523   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1524       : PyValue(operationRef, value) {}
1525   PyConcreteValue(PyValue &orig)
1526       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1527 
1528   /// Attempts to cast the original value to the derived type and throws on
1529   /// type mismatches.
1530   static MlirValue castFrom(PyValue &orig) {
1531     if (!DerivedTy::isaFunction(orig.get())) {
1532       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1533       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1534                                              DerivedTy::pyClassName +
1535                                              " (from " + origRepr + ")");
1536     }
1537     return orig.get();
1538   }
1539 
1540   /// Binds the Python module objects to functions of this class.
1541   static void bind(py::module &m) {
1542     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1543     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1544     DerivedTy::bindDerived(cls);
1545   }
1546 
1547   /// Implemented by derived classes to add methods to the Python subclass.
1548   static void bindDerived(ClassTy &m) {}
1549 };
1550 
1551 /// Python wrapper for MlirBlockArgument.
1552 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1553 public:
1554   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1555   static constexpr const char *pyClassName = "BlockArgument";
1556   using PyConcreteValue::PyConcreteValue;
1557 
1558   static void bindDerived(ClassTy &c) {
1559     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1560       return PyBlock(self.getParentOperation(),
1561                      mlirBlockArgumentGetOwner(self.get()));
1562     });
1563     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1564       return mlirBlockArgumentGetArgNumber(self.get());
1565     });
1566     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1567       return mlirBlockArgumentSetType(self.get(), type);
1568     });
1569   }
1570 };
1571 
1572 /// Python wrapper for MlirOpResult.
1573 class PyOpResult : public PyConcreteValue<PyOpResult> {
1574 public:
1575   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1576   static constexpr const char *pyClassName = "OpResult";
1577   using PyConcreteValue::PyConcreteValue;
1578 
1579   static void bindDerived(ClassTy &c) {
1580     c.def_property_readonly("owner", [](PyOpResult &self) {
1581       assert(
1582           mlirOperationEqual(self.getParentOperation()->get(),
1583                              mlirOpResultGetOwner(self.get())) &&
1584           "expected the owner of the value in Python to match that in the IR");
1585       return self.getParentOperation().getObject();
1586     });
1587     c.def_property_readonly("result_number", [](PyOpResult &self) {
1588       return mlirOpResultGetResultNumber(self.get());
1589     });
1590   }
1591 };
1592 
1593 /// Returns the list of types of the values held by container.
1594 template <typename Container>
1595 static std::vector<PyType> getValueTypes(Container &container,
1596                                          PyMlirContextRef &context) {
1597   std::vector<PyType> result;
1598   result.reserve(container.getNumElements());
1599   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1600     result.push_back(
1601         PyType(context, mlirValueGetType(container.getElement(i).get())));
1602   }
1603   return result;
1604 }
1605 
1606 /// A list of block arguments. Internally, these are stored as consecutive
1607 /// elements, random access is cheap. The argument list is associated with the
1608 /// operation that contains the block (detached blocks are not allowed in
1609 /// Python bindings) and extends its lifetime.
1610 class PyBlockArgumentList
1611     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1612 public:
1613   static constexpr const char *pyClassName = "BlockArgumentList";
1614 
1615   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1616                       intptr_t startIndex = 0, intptr_t length = -1,
1617                       intptr_t step = 1)
1618       : Sliceable(startIndex,
1619                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1620                   step),
1621         operation(std::move(operation)), block(block) {}
1622 
1623   /// Returns the number of arguments in the list.
1624   intptr_t getNumElements() {
1625     operation->checkValid();
1626     return mlirBlockGetNumArguments(block);
1627   }
1628 
1629   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1630   PyBlockArgument getElement(intptr_t pos) {
1631     MlirValue argument = mlirBlockGetArgument(block, pos);
1632     return PyBlockArgument(operation, argument);
1633   }
1634 
1635   /// Returns a sublist of this list.
1636   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1637                             intptr_t step) {
1638     return PyBlockArgumentList(operation, block, startIndex, length, step);
1639   }
1640 
1641   static void bindDerived(ClassTy &c) {
1642     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1643       return getValueTypes(self, self.operation->getContext());
1644     });
1645   }
1646 
1647 private:
1648   PyOperationRef operation;
1649   MlirBlock block;
1650 };
1651 
1652 /// A list of operation operands. Internally, these are stored as consecutive
1653 /// elements, random access is cheap. The result list is associated with the
1654 /// operation whose results these are, and extends the lifetime of this
1655 /// operation.
1656 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1657 public:
1658   static constexpr const char *pyClassName = "OpOperandList";
1659 
1660   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1661                   intptr_t length = -1, intptr_t step = 1)
1662       : Sliceable(startIndex,
1663                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1664                                : length,
1665                   step),
1666         operation(operation) {}
1667 
1668   intptr_t getNumElements() {
1669     operation->checkValid();
1670     return mlirOperationGetNumOperands(operation->get());
1671   }
1672 
1673   PyValue getElement(intptr_t pos) {
1674     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1675     MlirOperation owner;
1676     if (mlirValueIsAOpResult(operand))
1677       owner = mlirOpResultGetOwner(operand);
1678     else if (mlirValueIsABlockArgument(operand))
1679       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1680     else
1681       assert(false && "Value must be an block arg or op result.");
1682     PyOperationRef pyOwner =
1683         PyOperation::forOperation(operation->getContext(), owner);
1684     return PyValue(pyOwner, operand);
1685   }
1686 
1687   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1688     return PyOpOperandList(operation, startIndex, length, step);
1689   }
1690 
1691   void dunderSetItem(intptr_t index, PyValue value) {
1692     index = wrapIndex(index);
1693     mlirOperationSetOperand(operation->get(), index, value.get());
1694   }
1695 
1696   static void bindDerived(ClassTy &c) {
1697     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1698   }
1699 
1700 private:
1701   PyOperationRef operation;
1702 };
1703 
1704 /// A list of operation results. Internally, these are stored as consecutive
1705 /// elements, random access is cheap. The result list is associated with the
1706 /// operation whose results these are, and extends the lifetime of this
1707 /// operation.
1708 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1709 public:
1710   static constexpr const char *pyClassName = "OpResultList";
1711 
1712   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1713                  intptr_t length = -1, intptr_t step = 1)
1714       : Sliceable(startIndex,
1715                   length == -1 ? mlirOperationGetNumResults(operation->get())
1716                                : length,
1717                   step),
1718         operation(operation) {}
1719 
1720   intptr_t getNumElements() {
1721     operation->checkValid();
1722     return mlirOperationGetNumResults(operation->get());
1723   }
1724 
1725   PyOpResult getElement(intptr_t index) {
1726     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1727     return PyOpResult(value);
1728   }
1729 
1730   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1731     return PyOpResultList(operation, startIndex, length, step);
1732   }
1733 
1734   static void bindDerived(ClassTy &c) {
1735     c.def_property_readonly("types", [](PyOpResultList &self) {
1736       return getValueTypes(self, self.operation->getContext());
1737     });
1738   }
1739 
1740 private:
1741   PyOperationRef operation;
1742 };
1743 
1744 /// A list of operation attributes. Can be indexed by name, producing
1745 /// attributes, or by index, producing named attributes.
1746 class PyOpAttributeMap {
1747 public:
1748   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1749 
1750   PyAttribute dunderGetItemNamed(const std::string &name) {
1751     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1752                                                          toMlirStringRef(name));
1753     if (mlirAttributeIsNull(attr)) {
1754       throw SetPyError(PyExc_KeyError,
1755                        "attempt to access a non-existent attribute");
1756     }
1757     return PyAttribute(operation->getContext(), attr);
1758   }
1759 
1760   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1761     if (index < 0 || index >= dunderLen()) {
1762       throw SetPyError(PyExc_IndexError,
1763                        "attempt to access out of bounds attribute");
1764     }
1765     MlirNamedAttribute namedAttr =
1766         mlirOperationGetAttribute(operation->get(), index);
1767     return PyNamedAttribute(
1768         namedAttr.attribute,
1769         std::string(mlirIdentifierStr(namedAttr.name).data));
1770   }
1771 
1772   void dunderSetItem(const std::string &name, PyAttribute attr) {
1773     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1774                                     attr);
1775   }
1776 
1777   void dunderDelItem(const std::string &name) {
1778     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1779                                                      toMlirStringRef(name));
1780     if (!removed)
1781       throw SetPyError(PyExc_KeyError,
1782                        "attempt to delete a non-existent attribute");
1783   }
1784 
1785   intptr_t dunderLen() {
1786     return mlirOperationGetNumAttributes(operation->get());
1787   }
1788 
1789   bool dunderContains(const std::string &name) {
1790     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1791         operation->get(), toMlirStringRef(name)));
1792   }
1793 
1794   static void bind(py::module &m) {
1795     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
1796         .def("__contains__", &PyOpAttributeMap::dunderContains)
1797         .def("__len__", &PyOpAttributeMap::dunderLen)
1798         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1799         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1800         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1801         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1802   }
1803 
1804 private:
1805   PyOperationRef operation;
1806 };
1807 
1808 } // end namespace
1809 
1810 //------------------------------------------------------------------------------
1811 // Populates the core exports of the 'ir' submodule.
1812 //------------------------------------------------------------------------------
1813 
1814 void mlir::python::populateIRCore(py::module &m) {
1815   //----------------------------------------------------------------------------
1816   // Mapping of MlirContext.
1817   //----------------------------------------------------------------------------
1818   py::class_<PyMlirContext>(m, "Context", py::module_local())
1819       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1820       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1821       .def("_get_context_again",
1822            [](PyMlirContext &self) {
1823              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1824              return ref.releaseObject();
1825            })
1826       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1827       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1828       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1829                              &PyMlirContext::getCapsule)
1830       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1831       .def("__enter__", &PyMlirContext::contextEnter)
1832       .def("__exit__", &PyMlirContext::contextExit)
1833       .def_property_readonly_static(
1834           "current",
1835           [](py::object & /*class*/) {
1836             auto *context = PyThreadContextEntry::getDefaultContext();
1837             if (!context)
1838               throw SetPyError(PyExc_ValueError, "No current Context");
1839             return context;
1840           },
1841           "Gets the Context bound to the current thread or raises ValueError")
1842       .def_property_readonly(
1843           "dialects",
1844           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1845           "Gets a container for accessing dialects by name")
1846       .def_property_readonly(
1847           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1848           "Alias for 'dialect'")
1849       .def(
1850           "get_dialect_descriptor",
1851           [=](PyMlirContext &self, std::string &name) {
1852             MlirDialect dialect = mlirContextGetOrLoadDialect(
1853                 self.get(), {name.data(), name.size()});
1854             if (mlirDialectIsNull(dialect)) {
1855               throw SetPyError(PyExc_ValueError,
1856                                Twine("Dialect '") + name + "' not found");
1857             }
1858             return PyDialectDescriptor(self.getRef(), dialect);
1859           },
1860           "Gets or loads a dialect by name, returning its descriptor object")
1861       .def_property(
1862           "allow_unregistered_dialects",
1863           [](PyMlirContext &self) -> bool {
1864             return mlirContextGetAllowUnregisteredDialects(self.get());
1865           },
1866           [](PyMlirContext &self, bool value) {
1867             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1868           })
1869       .def("enable_multithreading",
1870            [](PyMlirContext &self, bool enable) {
1871              mlirContextEnableMultithreading(self.get(), enable);
1872            })
1873       .def("is_registered_operation",
1874            [](PyMlirContext &self, std::string &name) {
1875              return mlirContextIsRegisteredOperation(
1876                  self.get(), MlirStringRef{name.data(), name.size()});
1877            });
1878 
1879   //----------------------------------------------------------------------------
1880   // Mapping of PyDialectDescriptor
1881   //----------------------------------------------------------------------------
1882   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
1883       .def_property_readonly("namespace",
1884                              [](PyDialectDescriptor &self) {
1885                                MlirStringRef ns =
1886                                    mlirDialectGetNamespace(self.get());
1887                                return py::str(ns.data, ns.length);
1888                              })
1889       .def("__repr__", [](PyDialectDescriptor &self) {
1890         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1891         std::string repr("<DialectDescriptor ");
1892         repr.append(ns.data, ns.length);
1893         repr.append(">");
1894         return repr;
1895       });
1896 
1897   //----------------------------------------------------------------------------
1898   // Mapping of PyDialects
1899   //----------------------------------------------------------------------------
1900   py::class_<PyDialects>(m, "Dialects", py::module_local())
1901       .def("__getitem__",
1902            [=](PyDialects &self, std::string keyName) {
1903              MlirDialect dialect =
1904                  self.getDialectForKey(keyName, /*attrError=*/false);
1905              py::object descriptor =
1906                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1907              return createCustomDialectWrapper(keyName, std::move(descriptor));
1908            })
1909       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1910         MlirDialect dialect =
1911             self.getDialectForKey(attrName, /*attrError=*/true);
1912         py::object descriptor =
1913             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1914         return createCustomDialectWrapper(attrName, std::move(descriptor));
1915       });
1916 
1917   //----------------------------------------------------------------------------
1918   // Mapping of PyDialect
1919   //----------------------------------------------------------------------------
1920   py::class_<PyDialect>(m, "Dialect", py::module_local())
1921       .def(py::init<py::object>(), "descriptor")
1922       .def_property_readonly(
1923           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1924       .def("__repr__", [](py::object self) {
1925         auto clazz = self.attr("__class__");
1926         return py::str("<Dialect ") +
1927                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1928                clazz.attr("__module__") + py::str(".") +
1929                clazz.attr("__name__") + py::str(")>");
1930       });
1931 
1932   //----------------------------------------------------------------------------
1933   // Mapping of Location
1934   //----------------------------------------------------------------------------
1935   py::class_<PyLocation>(m, "Location", py::module_local())
1936       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
1937       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
1938       .def("__enter__", &PyLocation::contextEnter)
1939       .def("__exit__", &PyLocation::contextExit)
1940       .def("__eq__",
1941            [](PyLocation &self, PyLocation &other) -> bool {
1942              return mlirLocationEqual(self, other);
1943            })
1944       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
1945       .def_property_readonly_static(
1946           "current",
1947           [](py::object & /*class*/) {
1948             auto *loc = PyThreadContextEntry::getDefaultLocation();
1949             if (!loc)
1950               throw SetPyError(PyExc_ValueError, "No current Location");
1951             return loc;
1952           },
1953           "Gets the Location bound to the current thread or raises ValueError")
1954       .def_static(
1955           "unknown",
1956           [](DefaultingPyMlirContext context) {
1957             return PyLocation(context->getRef(),
1958                               mlirLocationUnknownGet(context->get()));
1959           },
1960           py::arg("context") = py::none(),
1961           "Gets a Location representing an unknown location")
1962       .def_static(
1963           "file",
1964           [](std::string filename, int line, int col,
1965              DefaultingPyMlirContext context) {
1966             return PyLocation(
1967                 context->getRef(),
1968                 mlirLocationFileLineColGet(
1969                     context->get(), toMlirStringRef(filename), line, col));
1970           },
1971           py::arg("filename"), py::arg("line"), py::arg("col"),
1972           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
1973       .def_property_readonly(
1974           "context",
1975           [](PyLocation &self) { return self.getContext().getObject(); },
1976           "Context that owns the Location")
1977       .def("__repr__", [](PyLocation &self) {
1978         PyPrintAccumulator printAccum;
1979         mlirLocationPrint(self, printAccum.getCallback(),
1980                           printAccum.getUserData());
1981         return printAccum.join();
1982       });
1983 
1984   //----------------------------------------------------------------------------
1985   // Mapping of Module
1986   //----------------------------------------------------------------------------
1987   py::class_<PyModule>(m, "Module", py::module_local())
1988       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
1989       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
1990       .def_static(
1991           "parse",
1992           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
1993             MlirModule module = mlirModuleCreateParse(
1994                 context->get(), toMlirStringRef(moduleAsm));
1995             // TODO: Rework error reporting once diagnostic engine is exposed
1996             // in C API.
1997             if (mlirModuleIsNull(module)) {
1998               throw SetPyError(
1999                   PyExc_ValueError,
2000                   "Unable to parse module assembly (see diagnostics)");
2001             }
2002             return PyModule::forModule(module).releaseObject();
2003           },
2004           py::arg("asm"), py::arg("context") = py::none(),
2005           kModuleParseDocstring)
2006       .def_static(
2007           "create",
2008           [](DefaultingPyLocation loc) {
2009             MlirModule module = mlirModuleCreateEmpty(loc);
2010             return PyModule::forModule(module).releaseObject();
2011           },
2012           py::arg("loc") = py::none(), "Creates an empty module")
2013       .def_property_readonly(
2014           "context",
2015           [](PyModule &self) { return self.getContext().getObject(); },
2016           "Context that created the Module")
2017       .def_property_readonly(
2018           "operation",
2019           [](PyModule &self) {
2020             return PyOperation::forOperation(self.getContext(),
2021                                              mlirModuleGetOperation(self.get()),
2022                                              self.getRef().releaseObject())
2023                 .releaseObject();
2024           },
2025           "Accesses the module as an operation")
2026       .def_property_readonly(
2027           "body",
2028           [](PyModule &self) {
2029             PyOperationRef module_op = PyOperation::forOperation(
2030                 self.getContext(), mlirModuleGetOperation(self.get()),
2031                 self.getRef().releaseObject());
2032             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2033             return returnBlock;
2034           },
2035           "Return the block for this module")
2036       .def(
2037           "dump",
2038           [](PyModule &self) {
2039             mlirOperationDump(mlirModuleGetOperation(self.get()));
2040           },
2041           kDumpDocstring)
2042       .def(
2043           "__str__",
2044           [](PyModule &self) {
2045             MlirOperation operation = mlirModuleGetOperation(self.get());
2046             PyPrintAccumulator printAccum;
2047             mlirOperationPrint(operation, printAccum.getCallback(),
2048                                printAccum.getUserData());
2049             return printAccum.join();
2050           },
2051           kOperationStrDunderDocstring);
2052 
2053   //----------------------------------------------------------------------------
2054   // Mapping of Operation.
2055   //----------------------------------------------------------------------------
2056   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2057       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2058                              [](PyOperationBase &self) {
2059                                return self.getOperation().getCapsule();
2060                              })
2061       .def("__eq__",
2062            [](PyOperationBase &self, PyOperationBase &other) {
2063              return &self.getOperation() == &other.getOperation();
2064            })
2065       .def("__eq__",
2066            [](PyOperationBase &self, py::object other) { return false; })
2067       .def_property_readonly("attributes",
2068                              [](PyOperationBase &self) {
2069                                return PyOpAttributeMap(
2070                                    self.getOperation().getRef());
2071                              })
2072       .def_property_readonly("operands",
2073                              [](PyOperationBase &self) {
2074                                return PyOpOperandList(
2075                                    self.getOperation().getRef());
2076                              })
2077       .def_property_readonly("regions",
2078                              [](PyOperationBase &self) {
2079                                return PyRegionList(
2080                                    self.getOperation().getRef());
2081                              })
2082       .def_property_readonly(
2083           "results",
2084           [](PyOperationBase &self) {
2085             return PyOpResultList(self.getOperation().getRef());
2086           },
2087           "Returns the list of Operation results.")
2088       .def_property_readonly(
2089           "result",
2090           [](PyOperationBase &self) {
2091             auto &operation = self.getOperation();
2092             auto numResults = mlirOperationGetNumResults(operation);
2093             if (numResults != 1) {
2094               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2095               throw SetPyError(
2096                   PyExc_ValueError,
2097                   Twine("Cannot call .result on operation ") +
2098                       StringRef(name.data, name.length) + " which has " +
2099                       Twine(numResults) +
2100                       " results (it is only valid for operations with a "
2101                       "single result)");
2102             }
2103             return PyOpResult(operation.getRef(),
2104                               mlirOperationGetResult(operation, 0));
2105           },
2106           "Shortcut to get an op result if it has only one (throws an error "
2107           "otherwise).")
2108       .def("__iter__",
2109            [](PyOperationBase &self) {
2110              return PyRegionIterator(self.getOperation().getRef());
2111            })
2112       .def(
2113           "__str__",
2114           [](PyOperationBase &self) {
2115             return self.getAsm(/*binary=*/false,
2116                                /*largeElementsLimit=*/llvm::None,
2117                                /*enableDebugInfo=*/false,
2118                                /*prettyDebugInfo=*/false,
2119                                /*printGenericOpForm=*/false,
2120                                /*useLocalScope=*/false);
2121           },
2122           "Returns the assembly form of the operation.")
2123       .def("print", &PyOperationBase::print,
2124            // Careful: Lots of arguments must match up with print method.
2125            py::arg("file") = py::none(), py::arg("binary") = false,
2126            py::arg("large_elements_limit") = py::none(),
2127            py::arg("enable_debug_info") = false,
2128            py::arg("pretty_debug_info") = false,
2129            py::arg("print_generic_op_form") = false,
2130            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2131       .def("get_asm", &PyOperationBase::getAsm,
2132            // Careful: Lots of arguments must match up with get_asm method.
2133            py::arg("binary") = false,
2134            py::arg("large_elements_limit") = py::none(),
2135            py::arg("enable_debug_info") = false,
2136            py::arg("pretty_debug_info") = false,
2137            py::arg("print_generic_op_form") = false,
2138            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2139       .def(
2140           "verify",
2141           [](PyOperationBase &self) {
2142             return mlirOperationVerify(self.getOperation());
2143           },
2144           "Verify the operation and return true if it passes, false if it "
2145           "fails.");
2146 
2147   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2148       .def_static("create", &PyOperation::create, py::arg("name"),
2149                   py::arg("results") = py::none(),
2150                   py::arg("operands") = py::none(),
2151                   py::arg("attributes") = py::none(),
2152                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2153                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2154                   kOperationCreateDocstring)
2155       .def_property_readonly("parent",
2156                              [](PyOperation &self) -> py::object {
2157                                auto parent = self.getParentOperation();
2158                                if (parent)
2159                                  return parent->getObject();
2160                                return py::none();
2161                              })
2162       .def("erase", &PyOperation::erase)
2163       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2164                              &PyOperation::getCapsule)
2165       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2166       .def_property_readonly("name",
2167                              [](PyOperation &self) {
2168                                self.checkValid();
2169                                MlirOperation operation = self.get();
2170                                MlirStringRef name = mlirIdentifierStr(
2171                                    mlirOperationGetName(operation));
2172                                return py::str(name.data, name.length);
2173                              })
2174       .def_property_readonly(
2175           "context",
2176           [](PyOperation &self) {
2177             self.checkValid();
2178             return self.getContext().getObject();
2179           },
2180           "Context that owns the Operation")
2181       .def_property_readonly("opview", &PyOperation::createOpView);
2182 
2183   auto opViewClass =
2184       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2185           .def(py::init<py::object>())
2186           .def_property_readonly("operation", &PyOpView::getOperationObject)
2187           .def_property_readonly(
2188               "context",
2189               [](PyOpView &self) {
2190                 return self.getOperation().getContext().getObject();
2191               },
2192               "Context that owns the Operation")
2193           .def("__str__", [](PyOpView &self) {
2194             return py::str(self.getOperationObject());
2195           });
2196   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2197   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2198   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2199   opViewClass.attr("build_generic") = classmethod(
2200       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2201       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2202       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2203       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2204       "Builds a specific, generated OpView based on class level attributes.");
2205 
2206   //----------------------------------------------------------------------------
2207   // Mapping of PyRegion.
2208   //----------------------------------------------------------------------------
2209   py::class_<PyRegion>(m, "Region", py::module_local())
2210       .def_property_readonly(
2211           "blocks",
2212           [](PyRegion &self) {
2213             return PyBlockList(self.getParentOperation(), self.get());
2214           },
2215           "Returns a forward-optimized sequence of blocks.")
2216       .def(
2217           "__iter__",
2218           [](PyRegion &self) {
2219             self.checkValid();
2220             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2221             return PyBlockIterator(self.getParentOperation(), firstBlock);
2222           },
2223           "Iterates over blocks in the region.")
2224       .def("__eq__",
2225            [](PyRegion &self, PyRegion &other) {
2226              return self.get().ptr == other.get().ptr;
2227            })
2228       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2229 
2230   //----------------------------------------------------------------------------
2231   // Mapping of PyBlock.
2232   //----------------------------------------------------------------------------
2233   py::class_<PyBlock>(m, "Block", py::module_local())
2234       .def_property_readonly(
2235           "owner",
2236           [](PyBlock &self) {
2237             return self.getParentOperation()->createOpView();
2238           },
2239           "Returns the owning operation of this block.")
2240       .def_property_readonly(
2241           "region",
2242           [](PyBlock &self) {
2243             MlirRegion region = mlirBlockGetParentRegion(self.get());
2244             return PyRegion(self.getParentOperation(), region);
2245           },
2246           "Returns the owning region of this block.")
2247       .def_property_readonly(
2248           "arguments",
2249           [](PyBlock &self) {
2250             return PyBlockArgumentList(self.getParentOperation(), self.get());
2251           },
2252           "Returns a list of block arguments.")
2253       .def_property_readonly(
2254           "operations",
2255           [](PyBlock &self) {
2256             return PyOperationList(self.getParentOperation(), self.get());
2257           },
2258           "Returns a forward-optimized sequence of operations.")
2259       .def(
2260           "create_before",
2261           [](PyBlock &self, py::args pyArgTypes) {
2262             self.checkValid();
2263             llvm::SmallVector<MlirType, 4> argTypes;
2264             argTypes.reserve(pyArgTypes.size());
2265             for (auto &pyArg : pyArgTypes) {
2266               argTypes.push_back(pyArg.cast<PyType &>());
2267             }
2268 
2269             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2270             MlirRegion region = mlirBlockGetParentRegion(self.get());
2271             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2272             return PyBlock(self.getParentOperation(), block);
2273           },
2274           "Creates and returns a new Block before this block "
2275           "(with given argument types).")
2276       .def(
2277           "create_after",
2278           [](PyBlock &self, py::args pyArgTypes) {
2279             self.checkValid();
2280             llvm::SmallVector<MlirType, 4> argTypes;
2281             argTypes.reserve(pyArgTypes.size());
2282             for (auto &pyArg : pyArgTypes) {
2283               argTypes.push_back(pyArg.cast<PyType &>());
2284             }
2285 
2286             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2287             MlirRegion region = mlirBlockGetParentRegion(self.get());
2288             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2289             return PyBlock(self.getParentOperation(), block);
2290           },
2291           "Creates and returns a new Block after this block "
2292           "(with given argument types).")
2293       .def(
2294           "__iter__",
2295           [](PyBlock &self) {
2296             self.checkValid();
2297             MlirOperation firstOperation =
2298                 mlirBlockGetFirstOperation(self.get());
2299             return PyOperationIterator(self.getParentOperation(),
2300                                        firstOperation);
2301           },
2302           "Iterates over operations in the block.")
2303       .def("__eq__",
2304            [](PyBlock &self, PyBlock &other) {
2305              return self.get().ptr == other.get().ptr;
2306            })
2307       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2308       .def(
2309           "__str__",
2310           [](PyBlock &self) {
2311             self.checkValid();
2312             PyPrintAccumulator printAccum;
2313             mlirBlockPrint(self.get(), printAccum.getCallback(),
2314                            printAccum.getUserData());
2315             return printAccum.join();
2316           },
2317           "Returns the assembly form of the block.");
2318 
2319   //----------------------------------------------------------------------------
2320   // Mapping of PyInsertionPoint.
2321   //----------------------------------------------------------------------------
2322 
2323   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2324       .def(py::init<PyBlock &>(), py::arg("block"),
2325            "Inserts after the last operation but still inside the block.")
2326       .def("__enter__", &PyInsertionPoint::contextEnter)
2327       .def("__exit__", &PyInsertionPoint::contextExit)
2328       .def_property_readonly_static(
2329           "current",
2330           [](py::object & /*class*/) {
2331             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2332             if (!ip)
2333               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2334             return ip;
2335           },
2336           "Gets the InsertionPoint bound to the current thread or raises "
2337           "ValueError if none has been set")
2338       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2339            "Inserts before a referenced operation.")
2340       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2341                   py::arg("block"), "Inserts at the beginning of the block.")
2342       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2343                   py::arg("block"), "Inserts before the block terminator.")
2344       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2345            "Inserts an operation.")
2346       .def_property_readonly(
2347           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2348           "Returns the block that this InsertionPoint points to.");
2349 
2350   //----------------------------------------------------------------------------
2351   // Mapping of PyAttribute.
2352   //----------------------------------------------------------------------------
2353   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2354       // Delegate to the PyAttribute copy constructor, which will also lifetime
2355       // extend the backing context which owns the MlirAttribute.
2356       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2357            "Casts the passed attribute to the generic Attribute")
2358       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2359                              &PyAttribute::getCapsule)
2360       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2361       .def_static(
2362           "parse",
2363           [](std::string attrSpec, DefaultingPyMlirContext context) {
2364             MlirAttribute type = mlirAttributeParseGet(
2365                 context->get(), toMlirStringRef(attrSpec));
2366             // TODO: Rework error reporting once diagnostic engine is exposed
2367             // in C API.
2368             if (mlirAttributeIsNull(type)) {
2369               throw SetPyError(PyExc_ValueError,
2370                                Twine("Unable to parse attribute: '") +
2371                                    attrSpec + "'");
2372             }
2373             return PyAttribute(context->getRef(), type);
2374           },
2375           py::arg("asm"), py::arg("context") = py::none(),
2376           "Parses an attribute from an assembly form")
2377       .def_property_readonly(
2378           "context",
2379           [](PyAttribute &self) { return self.getContext().getObject(); },
2380           "Context that owns the Attribute")
2381       .def_property_readonly("type",
2382                              [](PyAttribute &self) {
2383                                return PyType(self.getContext()->getRef(),
2384                                              mlirAttributeGetType(self));
2385                              })
2386       .def(
2387           "get_named",
2388           [](PyAttribute &self, std::string name) {
2389             return PyNamedAttribute(self, std::move(name));
2390           },
2391           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2392       .def("__eq__",
2393            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2394       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2395       .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; })
2396       .def(
2397           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2398           kDumpDocstring)
2399       .def(
2400           "__str__",
2401           [](PyAttribute &self) {
2402             PyPrintAccumulator printAccum;
2403             mlirAttributePrint(self, printAccum.getCallback(),
2404                                printAccum.getUserData());
2405             return printAccum.join();
2406           },
2407           "Returns the assembly form of the Attribute.")
2408       .def("__repr__", [](PyAttribute &self) {
2409         // Generally, assembly formats are not printed for __repr__ because
2410         // this can cause exceptionally long debug output and exceptions.
2411         // However, attribute values are generally considered useful and are
2412         // printed. This may need to be re-evaluated if debug dumps end up
2413         // being excessive.
2414         PyPrintAccumulator printAccum;
2415         printAccum.parts.append("Attribute(");
2416         mlirAttributePrint(self, printAccum.getCallback(),
2417                            printAccum.getUserData());
2418         printAccum.parts.append(")");
2419         return printAccum.join();
2420       });
2421 
2422   //----------------------------------------------------------------------------
2423   // Mapping of PyNamedAttribute
2424   //----------------------------------------------------------------------------
2425   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2426       .def("__repr__",
2427            [](PyNamedAttribute &self) {
2428              PyPrintAccumulator printAccum;
2429              printAccum.parts.append("NamedAttribute(");
2430              printAccum.parts.append(
2431                  mlirIdentifierStr(self.namedAttr.name).data);
2432              printAccum.parts.append("=");
2433              mlirAttributePrint(self.namedAttr.attribute,
2434                                 printAccum.getCallback(),
2435                                 printAccum.getUserData());
2436              printAccum.parts.append(")");
2437              return printAccum.join();
2438            })
2439       .def_property_readonly(
2440           "name",
2441           [](PyNamedAttribute &self) {
2442             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2443                            mlirIdentifierStr(self.namedAttr.name).length);
2444           },
2445           "The name of the NamedAttribute binding")
2446       .def_property_readonly(
2447           "attr",
2448           [](PyNamedAttribute &self) {
2449             // TODO: When named attribute is removed/refactored, also remove
2450             // this constructor (it does an inefficient table lookup).
2451             auto contextRef = PyMlirContext::forContext(
2452                 mlirAttributeGetContext(self.namedAttr.attribute));
2453             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2454           },
2455           py::keep_alive<0, 1>(),
2456           "The underlying generic attribute of the NamedAttribute binding");
2457 
2458   //----------------------------------------------------------------------------
2459   // Mapping of PyType.
2460   //----------------------------------------------------------------------------
2461   py::class_<PyType>(m, "Type", py::module_local())
2462       // Delegate to the PyType copy constructor, which will also lifetime
2463       // extend the backing context which owns the MlirType.
2464       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2465            "Casts the passed type to the generic Type")
2466       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2467       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2468       .def_static(
2469           "parse",
2470           [](std::string typeSpec, DefaultingPyMlirContext context) {
2471             MlirType type =
2472                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2473             // TODO: Rework error reporting once diagnostic engine is exposed
2474             // in C API.
2475             if (mlirTypeIsNull(type)) {
2476               throw SetPyError(PyExc_ValueError,
2477                                Twine("Unable to parse type: '") + typeSpec +
2478                                    "'");
2479             }
2480             return PyType(context->getRef(), type);
2481           },
2482           py::arg("asm"), py::arg("context") = py::none(),
2483           kContextParseTypeDocstring)
2484       .def_property_readonly(
2485           "context", [](PyType &self) { return self.getContext().getObject(); },
2486           "Context that owns the Type")
2487       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2488       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2489       .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; })
2490       .def(
2491           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2492       .def(
2493           "__str__",
2494           [](PyType &self) {
2495             PyPrintAccumulator printAccum;
2496             mlirTypePrint(self, printAccum.getCallback(),
2497                           printAccum.getUserData());
2498             return printAccum.join();
2499           },
2500           "Returns the assembly form of the type.")
2501       .def("__repr__", [](PyType &self) {
2502         // Generally, assembly formats are not printed for __repr__ because
2503         // this can cause exceptionally long debug output and exceptions.
2504         // However, types are an exception as they typically have compact
2505         // assembly forms and printing them is useful.
2506         PyPrintAccumulator printAccum;
2507         printAccum.parts.append("Type(");
2508         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2509         printAccum.parts.append(")");
2510         return printAccum.join();
2511       });
2512 
2513   //----------------------------------------------------------------------------
2514   // Mapping of Value.
2515   //----------------------------------------------------------------------------
2516   py::class_<PyValue>(m, "Value", py::module_local())
2517       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2518       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2519       .def_property_readonly(
2520           "context",
2521           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2522           "Context in which the value lives.")
2523       .def(
2524           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2525           kDumpDocstring)
2526       .def_property_readonly(
2527           "owner",
2528           [](PyValue &self) {
2529             assert(mlirOperationEqual(self.getParentOperation()->get(),
2530                                       mlirOpResultGetOwner(self.get())) &&
2531                    "expected the owner of the value in Python to match that in "
2532                    "the IR");
2533             return self.getParentOperation().getObject();
2534           })
2535       .def("__eq__",
2536            [](PyValue &self, PyValue &other) {
2537              return self.get().ptr == other.get().ptr;
2538            })
2539       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2540       .def(
2541           "__str__",
2542           [](PyValue &self) {
2543             PyPrintAccumulator printAccum;
2544             printAccum.parts.append("Value(");
2545             mlirValuePrint(self.get(), printAccum.getCallback(),
2546                            printAccum.getUserData());
2547             printAccum.parts.append(")");
2548             return printAccum.join();
2549           },
2550           kValueDunderStrDocstring)
2551       .def_property_readonly("type", [](PyValue &self) {
2552         return PyType(self.getParentOperation()->getContext(),
2553                       mlirValueGetType(self.get()));
2554       });
  // Bindings for the block-argument and op-result value kinds. They are
  // registered after the Value class above; presumably the order is required
  // because they build on it on the Python side -- TODO confirm before
  // reordering.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings: the list/iterator/map helper classes returned by the
  // properties defined above (e.g. Region.blocks -> PyBlockList,
  // Block.operations -> PyOperationList, Block.arguments ->
  // PyBlockArgumentList).
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  PyGlobalDebugFlag::bind(m);
2572 }
2573