1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include <pybind11/stl.h>
23 
24 #include <utility>
25 
26 namespace py = pybind11;
27 using namespace mlir;
28 using namespace mlir::python;
29 
30 using llvm::SmallVector;
31 using llvm::StringRef;
32 using llvm::Twine;
33 
34 //------------------------------------------------------------------------------
35 // Docstrings (trivial, non-duplicated docstrings are included inline).
36 //------------------------------------------------------------------------------
37 
// Python-visible help text for Context methods. The raw-string contents are
// emitted verbatim as docstrings (presumably attached via `.def(...)` later in
// this file), so edits inside the literals change user-facing output.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Location factory docstrings (call-site, file, fused, name locations).
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

// Docstring for parsing a module from its textual assembly form.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring for the generic Operation.create entry point; the listed argument
// names should be kept in sync with the binding's keyword arguments.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
83 
// Docstring for Operation.print. Argument names listed here must match the
// keyword arguments accepted by the binding; the local-scope option is the
// snake_case keyword `use_local_scope` (previously misspelled here as
// `use_local_Scope`, which does not exist as a keyword).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";
109 
// Docstring for Operation.get_asm; it defers to print() for the shared
// formatting keyword arguments.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Docstring for Operation.__str__ (default formatting options only).
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared docstring for the various dump() methods.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring for BlockList.append (bound in PyBlockList::bind below).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Docstring for Value.__str__.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
148 
149 //------------------------------------------------------------------------------
150 // Utilities.
151 //------------------------------------------------------------------------------
152 
153 /// Helper for creating an @classmethod.
154 template <class Func, typename... Args>
155 py::object classmethod(Func f, Args... args) {
156   py::object cf = py::cpp_function(f, args...);
157   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
158 }
159 
160 static py::object
161 createCustomDialectWrapper(const std::string &dialectNamespace,
162                            py::object dialectDescriptor) {
163   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
164   if (!dialectClass) {
165     // Use the base class.
166     return py::cast(PyDialect(std::move(dialectDescriptor)));
167   }
168 
169   // Create the custom implementation.
170   return (*dialectClass)(std::move(dialectDescriptor));
171 }
172 
173 static MlirStringRef toMlirStringRef(const std::string &s) {
174   return mlirStringRefCreate(s.data(), s.size());
175 }
176 
177 /// Wrapper for the global LLVM debugging flag.
178 struct PyGlobalDebugFlag {
179   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
180 
181   static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
182 
183   static void bind(py::module &m) {
184     // Debug flags.
185     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
186         .def_property_static("flag", &PyGlobalDebugFlag::get,
187                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
188   }
189 };
190 
191 //------------------------------------------------------------------------------
192 // Collections.
193 //------------------------------------------------------------------------------
194 
195 namespace {
196 
197 class PyRegionIterator {
198 public:
199   PyRegionIterator(PyOperationRef operation)
200       : operation(std::move(operation)) {}
201 
202   PyRegionIterator &dunderIter() { return *this; }
203 
204   PyRegion dunderNext() {
205     operation->checkValid();
206     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
207       throw py::stop_iteration();
208     }
209     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
210     return PyRegion(operation, region);
211   }
212 
213   static void bind(py::module &m) {
214     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
215         .def("__iter__", &PyRegionIterator::dunderIter)
216         .def("__next__", &PyRegionIterator::dunderNext);
217   }
218 
219 private:
220   PyOperationRef operation;
221   int nextIndex = 0;
222 };
223 
224 /// Regions of an op are fixed length and indexed numerically so are represented
225 /// with a sequence-like container.
226 class PyRegionList {
227 public:
228   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
229 
230   intptr_t dunderLen() {
231     operation->checkValid();
232     return mlirOperationGetNumRegions(operation->get());
233   }
234 
235   PyRegion dunderGetItem(intptr_t index) {
236     // dunderLen checks validity.
237     if (index < 0 || index >= dunderLen()) {
238       throw SetPyError(PyExc_IndexError,
239                        "attempt to access out of bounds region");
240     }
241     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
242     return PyRegion(operation, region);
243   }
244 
245   static void bind(py::module &m) {
246     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
247         .def("__len__", &PyRegionList::dunderLen)
248         .def("__getitem__", &PyRegionList::dunderGetItem);
249   }
250 
251 private:
252   PyOperationRef operation;
253 };
254 
255 class PyBlockIterator {
256 public:
257   PyBlockIterator(PyOperationRef operation, MlirBlock next)
258       : operation(std::move(operation)), next(next) {}
259 
260   PyBlockIterator &dunderIter() { return *this; }
261 
262   PyBlock dunderNext() {
263     operation->checkValid();
264     if (mlirBlockIsNull(next)) {
265       throw py::stop_iteration();
266     }
267 
268     PyBlock returnBlock(operation, next);
269     next = mlirBlockGetNextInRegion(next);
270     return returnBlock;
271   }
272 
273   static void bind(py::module &m) {
274     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
275         .def("__iter__", &PyBlockIterator::dunderIter)
276         .def("__next__", &PyBlockIterator::dunderNext);
277   }
278 
279 private:
280   PyOperationRef operation;
281   MlirBlock next;
282 };
283 
284 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
285 /// we present them as a more full-featured list-like container but optimize
286 /// it for forward iteration. Blocks are always owned by a region.
287 class PyBlockList {
288 public:
289   PyBlockList(PyOperationRef operation, MlirRegion region)
290       : operation(std::move(operation)), region(region) {}
291 
292   PyBlockIterator dunderIter() {
293     operation->checkValid();
294     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
295   }
296 
297   intptr_t dunderLen() {
298     operation->checkValid();
299     intptr_t count = 0;
300     MlirBlock block = mlirRegionGetFirstBlock(region);
301     while (!mlirBlockIsNull(block)) {
302       count += 1;
303       block = mlirBlockGetNextInRegion(block);
304     }
305     return count;
306   }
307 
308   PyBlock dunderGetItem(intptr_t index) {
309     operation->checkValid();
310     if (index < 0) {
311       throw SetPyError(PyExc_IndexError,
312                        "attempt to access out of bounds block");
313     }
314     MlirBlock block = mlirRegionGetFirstBlock(region);
315     while (!mlirBlockIsNull(block)) {
316       if (index == 0) {
317         return PyBlock(operation, block);
318       }
319       block = mlirBlockGetNextInRegion(block);
320       index -= 1;
321     }
322     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
323   }
324 
325   PyBlock appendBlock(const py::args &pyArgTypes) {
326     operation->checkValid();
327     llvm::SmallVector<MlirType, 4> argTypes;
328     argTypes.reserve(pyArgTypes.size());
329     for (auto &pyArg : pyArgTypes) {
330       argTypes.push_back(pyArg.cast<PyType &>());
331     }
332 
333     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
334     mlirRegionAppendOwnedBlock(region, block);
335     return PyBlock(operation, block);
336   }
337 
338   static void bind(py::module &m) {
339     py::class_<PyBlockList>(m, "BlockList", py::module_local())
340         .def("__getitem__", &PyBlockList::dunderGetItem)
341         .def("__iter__", &PyBlockList::dunderIter)
342         .def("__len__", &PyBlockList::dunderLen)
343         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
344   }
345 
346 private:
347   PyOperationRef operation;
348   MlirRegion region;
349 };
350 
351 class PyOperationIterator {
352 public:
353   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
354       : parentOperation(std::move(parentOperation)), next(next) {}
355 
356   PyOperationIterator &dunderIter() { return *this; }
357 
358   py::object dunderNext() {
359     parentOperation->checkValid();
360     if (mlirOperationIsNull(next)) {
361       throw py::stop_iteration();
362     }
363 
364     PyOperationRef returnOperation =
365         PyOperation::forOperation(parentOperation->getContext(), next);
366     next = mlirOperationGetNextInBlock(next);
367     return returnOperation->createOpView();
368   }
369 
370   static void bind(py::module &m) {
371     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
372         .def("__iter__", &PyOperationIterator::dunderIter)
373         .def("__next__", &PyOperationIterator::dunderNext);
374   }
375 
376 private:
377   PyOperationRef parentOperation;
378   MlirOperation next;
379 };
380 
381 /// Operations are exposed by the C-API as a forward-only linked list. In
382 /// Python, we present them as a more full-featured list-like container but
383 /// optimize it for forward iteration. Iterable operations are always owned
384 /// by a block.
385 class PyOperationList {
386 public:
387   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
388       : parentOperation(std::move(parentOperation)), block(block) {}
389 
390   PyOperationIterator dunderIter() {
391     parentOperation->checkValid();
392     return PyOperationIterator(parentOperation,
393                                mlirBlockGetFirstOperation(block));
394   }
395 
396   intptr_t dunderLen() {
397     parentOperation->checkValid();
398     intptr_t count = 0;
399     MlirOperation childOp = mlirBlockGetFirstOperation(block);
400     while (!mlirOperationIsNull(childOp)) {
401       count += 1;
402       childOp = mlirOperationGetNextInBlock(childOp);
403     }
404     return count;
405   }
406 
407   py::object dunderGetItem(intptr_t index) {
408     parentOperation->checkValid();
409     if (index < 0) {
410       throw SetPyError(PyExc_IndexError,
411                        "attempt to access out of bounds operation");
412     }
413     MlirOperation childOp = mlirBlockGetFirstOperation(block);
414     while (!mlirOperationIsNull(childOp)) {
415       if (index == 0) {
416         return PyOperation::forOperation(parentOperation->getContext(), childOp)
417             ->createOpView();
418       }
419       childOp = mlirOperationGetNextInBlock(childOp);
420       index -= 1;
421     }
422     throw SetPyError(PyExc_IndexError,
423                      "attempt to access out of bounds operation");
424   }
425 
426   static void bind(py::module &m) {
427     py::class_<PyOperationList>(m, "OperationList", py::module_local())
428         .def("__getitem__", &PyOperationList::dunderGetItem)
429         .def("__iter__", &PyOperationList::dunderIter)
430         .def("__len__", &PyOperationList::dunderLen);
431   }
432 
433 private:
434   PyOperationRef parentOperation;
435   MlirBlock block;
436 };
437 
438 } // namespace
439 
440 //------------------------------------------------------------------------------
441 // PyMlirContext
442 //------------------------------------------------------------------------------
443 
444 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
445   py::gil_scoped_acquire acquire;
446   auto &liveContexts = getLiveContexts();
447   liveContexts[context.ptr] = this;
448 }
449 
450 PyMlirContext::~PyMlirContext() {
451   // Note that the only public way to construct an instance is via the
452   // forContext method, which always puts the associated handle into
453   // liveContexts.
454   py::gil_scoped_acquire acquire;
455   getLiveContexts().erase(context.ptr);
456   mlirContextDestroy(context);
457 }
458 
459 py::object PyMlirContext::getCapsule() {
460   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
461 }
462 
463 py::object PyMlirContext::createFromCapsule(py::object capsule) {
464   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
465   if (mlirContextIsNull(rawContext))
466     throw py::error_already_set();
467   return forContext(rawContext).releaseObject();
468 }
469 
470 PyMlirContext *PyMlirContext::createNewContextForInit() {
471   MlirContext context = mlirContextCreate();
472   mlirRegisterAllDialects(context);
473   return new PyMlirContext(context);
474 }
475 
476 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
477   py::gil_scoped_acquire acquire;
478   auto &liveContexts = getLiveContexts();
479   auto it = liveContexts.find(context.ptr);
480   if (it == liveContexts.end()) {
481     // Create.
482     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
483     py::object pyRef = py::cast(unownedContextWrapper);
484     assert(pyRef && "cast to py::object failed");
485     liveContexts[context.ptr] = unownedContextWrapper;
486     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
487   }
488   // Use existing.
489   py::object pyRef = py::cast(it->second);
490   return PyMlirContextRef(it->second, std::move(pyRef));
491 }
492 
493 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
494   static LiveContextMap liveContexts;
495   return liveContexts;
496 }
497 
498 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
499 
500 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
501 
502 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
503 
504 pybind11::object PyMlirContext::contextEnter() {
505   return PyThreadContextEntry::pushContext(*this);
506 }
507 
508 void PyMlirContext::contextExit(const pybind11::object &excType,
509                                 const pybind11::object &excVal,
510                                 const pybind11::object &excTb) {
511   PyThreadContextEntry::popContext(*this);
512 }
513 
514 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
515   // Note that ownership is transferred to the delete callback below by way of
516   // an explicit inc_ref (borrow).
517   PyDiagnosticHandler *pyHandler =
518       new PyDiagnosticHandler(get(), std::move(callback));
519   py::object pyHandlerObject =
520       py::cast(pyHandler, py::return_value_policy::take_ownership);
521   pyHandlerObject.inc_ref();
522 
523   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
524   // guaranteed to be known to pybind.
525   auto handlerCallback =
526       +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
527     PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
528     py::object pyDiagnosticObject =
529         py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
530 
531     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
532     bool result = false;
533     {
534       // Since this can be called from arbitrary C++ contexts, always get the
535       // gil.
536       py::gil_scoped_acquire gil;
537       try {
538         result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
539       } catch (std::exception &e) {
540         fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
541                 e.what());
542         pyHandler->hadError = true;
543       }
544     }
545 
546     pyDiagnostic->invalidate();
547     return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
548   };
549   auto deleteCallback = +[](void *userData) {
550     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
551     assert(pyHandler->registeredID && "handler is not registered");
552     pyHandler->registeredID.reset();
553 
554     // Decrement reference, balancing the inc_ref() above.
555     py::object pyHandlerObject =
556         py::cast(pyHandler, py::return_value_policy::reference);
557     pyHandlerObject.dec_ref();
558   };
559 
560   pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
561       get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
562   return pyHandlerObject;
563 }
564 
565 PyMlirContext &DefaultingPyMlirContext::resolve() {
566   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
567   if (!context) {
568     throw SetPyError(
569         PyExc_RuntimeError,
570         "An MLIR function requires a Context but none was provided in the call "
571         "or from the surrounding environment. Either pass to the function with "
572         "a 'context=' argument or establish a default using 'with Context():'");
573   }
574   return *context;
575 }
576 
577 //------------------------------------------------------------------------------
578 // PyThreadContextEntry management
579 //------------------------------------------------------------------------------
580 
581 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
582   static thread_local std::vector<PyThreadContextEntry> stack;
583   return stack;
584 }
585 
586 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
587   auto &stack = getStack();
588   if (stack.empty())
589     return nullptr;
590   return &stack.back();
591 }
592 
593 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
594                                 py::object insertionPoint,
595                                 py::object location) {
596   auto &stack = getStack();
597   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
598                      std::move(location));
599   // If the new stack has more than one entry and the context of the new top
600   // entry matches the previous, copy the insertionPoint and location from the
601   // previous entry if missing from the new top entry.
602   if (stack.size() > 1) {
603     auto &prev = *(stack.rbegin() + 1);
604     auto &current = stack.back();
605     if (current.context.is(prev.context)) {
606       // Default non-context objects from the previous entry.
607       if (!current.insertionPoint)
608         current.insertionPoint = prev.insertionPoint;
609       if (!current.location)
610         current.location = prev.location;
611     }
612   }
613 }
614 
615 PyMlirContext *PyThreadContextEntry::getContext() {
616   if (!context)
617     return nullptr;
618   return py::cast<PyMlirContext *>(context);
619 }
620 
621 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
622   if (!insertionPoint)
623     return nullptr;
624   return py::cast<PyInsertionPoint *>(insertionPoint);
625 }
626 
627 PyLocation *PyThreadContextEntry::getLocation() {
628   if (!location)
629     return nullptr;
630   return py::cast<PyLocation *>(location);
631 }
632 
633 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
634   auto *tos = getTopOfStack();
635   return tos ? tos->getContext() : nullptr;
636 }
637 
638 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
639   auto *tos = getTopOfStack();
640   return tos ? tos->getInsertionPoint() : nullptr;
641 }
642 
643 PyLocation *PyThreadContextEntry::getDefaultLocation() {
644   auto *tos = getTopOfStack();
645   return tos ? tos->getLocation() : nullptr;
646 }
647 
648 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
649   py::object contextObj = py::cast(context);
650   push(FrameKind::Context, /*context=*/contextObj,
651        /*insertionPoint=*/py::object(),
652        /*location=*/py::object());
653   return contextObj;
654 }
655 
656 void PyThreadContextEntry::popContext(PyMlirContext &context) {
657   auto &stack = getStack();
658   if (stack.empty())
659     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
660   auto &tos = stack.back();
661   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
662     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
663   stack.pop_back();
664 }
665 
666 py::object
667 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
668   py::object contextObj =
669       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
670   py::object insertionPointObj = py::cast(insertionPoint);
671   push(FrameKind::InsertionPoint,
672        /*context=*/contextObj,
673        /*insertionPoint=*/insertionPointObj,
674        /*location=*/py::object());
675   return insertionPointObj;
676 }
677 
678 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
679   auto &stack = getStack();
680   if (stack.empty())
681     throw SetPyError(PyExc_RuntimeError,
682                      "Unbalanced InsertionPoint enter/exit");
683   auto &tos = stack.back();
684   if (tos.frameKind != FrameKind::InsertionPoint &&
685       tos.getInsertionPoint() != &insertionPoint)
686     throw SetPyError(PyExc_RuntimeError,
687                      "Unbalanced InsertionPoint enter/exit");
688   stack.pop_back();
689 }
690 
691 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
692   py::object contextObj = location.getContext().getObject();
693   py::object locationObj = py::cast(location);
694   push(FrameKind::Location, /*context=*/contextObj,
695        /*insertionPoint=*/py::object(),
696        /*location=*/locationObj);
697   return locationObj;
698 }
699 
700 void PyThreadContextEntry::popLocation(PyLocation &location) {
701   auto &stack = getStack();
702   if (stack.empty())
703     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
704   auto &tos = stack.back();
705   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
706     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
707   stack.pop_back();
708 }
709 
710 //------------------------------------------------------------------------------
711 // PyDiagnostic*
712 //------------------------------------------------------------------------------
713 
714 void PyDiagnostic::invalidate() {
715   valid = false;
716   if (materializedNotes) {
717     for (auto &noteObject : *materializedNotes) {
718       PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
719       note->invalidate();
720     }
721   }
722 }
723 
724 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
725                                          py::object callback)
726     : context(context), callback(std::move(callback)) {}
727 
728 PyDiagnosticHandler::~PyDiagnosticHandler() = default;
729 
730 void PyDiagnosticHandler::detach() {
731   if (!registeredID)
732     return;
733   MlirDiagnosticHandlerID localID = *registeredID;
734   mlirContextDetachDiagnosticHandler(context, localID);
735   assert(!registeredID && "should have unregistered");
736   // Not strictly necessary but keeps stale pointers from being around to cause
737   // issues.
738   context = {nullptr};
739 }
740 
741 void PyDiagnostic::checkValid() {
742   if (!valid) {
743     throw std::invalid_argument(
744         "Diagnostic is invalid (used outside of callback)");
745   }
746 }
747 
748 MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
749   checkValid();
750   return mlirDiagnosticGetSeverity(diagnostic);
751 }
752 
753 PyLocation PyDiagnostic::getLocation() {
754   checkValid();
755   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
756   MlirContext context = mlirLocationGetContext(loc);
757   return PyLocation(PyMlirContext::forContext(context), loc);
758 }
759 
760 py::str PyDiagnostic::getMessage() {
761   checkValid();
762   py::object fileObject = py::module::import("io").attr("StringIO")();
763   PyFileAccumulator accum(fileObject, /*binary=*/false);
764   mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
765   return fileObject.attr("getvalue")();
766 }
767 
768 py::tuple PyDiagnostic::getNotes() {
769   checkValid();
770   if (materializedNotes)
771     return *materializedNotes;
772   intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
773   materializedNotes = py::tuple(numNotes);
774   for (intptr_t i = 0; i < numNotes; ++i) {
775     MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
776     py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag));
777     PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr());
778   }
779   return *materializedNotes;
780 }
781 
782 //------------------------------------------------------------------------------
783 // PyDialect, PyDialectDescriptor, PyDialects
784 //------------------------------------------------------------------------------
785 
786 MlirDialect PyDialects::getDialectForKey(const std::string &key,
787                                          bool attrError) {
788   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
789                                                     {key.data(), key.size()});
790   if (mlirDialectIsNull(dialect)) {
791     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
792                      Twine("Dialect '") + key + "' not found");
793   }
794   return dialect;
795 }
796 
797 //------------------------------------------------------------------------------
798 // PyLocation
799 //------------------------------------------------------------------------------
800 
801 py::object PyLocation::getCapsule() {
802   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
803 }
804 
805 PyLocation PyLocation::createFromCapsule(py::object capsule) {
806   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
807   if (mlirLocationIsNull(rawLoc))
808     throw py::error_already_set();
809   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
810                     rawLoc);
811 }
812 
813 py::object PyLocation::contextEnter() {
814   return PyThreadContextEntry::pushLocation(*this);
815 }
816 
817 void PyLocation::contextExit(const pybind11::object &excType,
818                              const pybind11::object &excVal,
819                              const pybind11::object &excTb) {
820   PyThreadContextEntry::popLocation(*this);
821 }
822 
823 PyLocation &DefaultingPyLocation::resolve() {
824   auto *location = PyThreadContextEntry::getDefaultLocation();
825   if (!location) {
826     throw SetPyError(
827         PyExc_RuntimeError,
828         "An MLIR function requires a Location but none was provided in the "
829         "call or from the surrounding environment. Either pass to the function "
830         "with a 'loc=' argument or establish a default using 'with loc:'");
831   }
832   return *location;
833 }
834 
835 //------------------------------------------------------------------------------
836 // PyModule
837 //------------------------------------------------------------------------------
838 
839 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
840     : BaseContextObject(std::move(contextRef)), module(module) {}
841 
842 PyModule::~PyModule() {
843   py::gil_scoped_acquire acquire;
844   auto &liveModules = getContext()->liveModules;
845   assert(liveModules.count(module.ptr) == 1 &&
846          "destroying module not in live map");
847   liveModules.erase(module.ptr);
848   mlirModuleDestroy(module);
849 }
850 
851 PyModuleRef PyModule::forModule(MlirModule module) {
852   MlirContext context = mlirModuleGetContext(module);
853   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
854 
855   py::gil_scoped_acquire acquire;
856   auto &liveModules = contextRef->liveModules;
857   auto it = liveModules.find(module.ptr);
858   if (it == liveModules.end()) {
859     // Create.
860     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
861     // Note that the default return value policy on cast is automatic_reference,
862     // which does not take ownership (delete will not be called).
863     // Just be explicit.
864     py::object pyRef =
865         py::cast(unownedModule, py::return_value_policy::take_ownership);
866     unownedModule->handle = pyRef;
867     liveModules[module.ptr] =
868         std::make_pair(unownedModule->handle, unownedModule);
869     return PyModuleRef(unownedModule, std::move(pyRef));
870   }
871   // Use existing.
872   PyModule *existing = it->second.second;
873   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
874   return PyModuleRef(existing, std::move(pyRef));
875 }
876 
877 py::object PyModule::createFromCapsule(py::object capsule) {
878   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
879   if (mlirModuleIsNull(rawModule))
880     throw py::error_already_set();
881   return forModule(rawModule).releaseObject();
882 }
883 
884 py::object PyModule::getCapsule() {
885   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
886 }
887 
888 //------------------------------------------------------------------------------
889 // PyOperation
890 //------------------------------------------------------------------------------
891 
892 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
893     : BaseContextObject(std::move(contextRef)), operation(operation) {}
894 
895 PyOperation::~PyOperation() {
896   // If the operation has already been invalidated there is nothing to do.
897   if (!valid)
898     return;
899   auto &liveOperations = getContext()->liveOperations;
900   assert(liveOperations.count(operation.ptr) == 1 &&
901          "destroying operation not in live map");
902   liveOperations.erase(operation.ptr);
903   if (!isAttached()) {
904     mlirOperationDestroy(operation);
905   }
906 }
907 
908 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
909                                            MlirOperation operation,
910                                            py::object parentKeepAlive) {
911   auto &liveOperations = contextRef->liveOperations;
912   // Create.
913   PyOperation *unownedOperation =
914       new PyOperation(std::move(contextRef), operation);
915   // Note that the default return value policy on cast is automatic_reference,
916   // which does not take ownership (delete will not be called).
917   // Just be explicit.
918   py::object pyRef =
919       py::cast(unownedOperation, py::return_value_policy::take_ownership);
920   unownedOperation->handle = pyRef;
921   if (parentKeepAlive) {
922     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
923   }
924   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
925   return PyOperationRef(unownedOperation, std::move(pyRef));
926 }
927 
928 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
929                                          MlirOperation operation,
930                                          py::object parentKeepAlive) {
931   auto &liveOperations = contextRef->liveOperations;
932   auto it = liveOperations.find(operation.ptr);
933   if (it == liveOperations.end()) {
934     // Create.
935     return createInstance(std::move(contextRef), operation,
936                           std::move(parentKeepAlive));
937   }
938   // Use existing.
939   PyOperation *existing = it->second.second;
940   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
941   return PyOperationRef(existing, std::move(pyRef));
942 }
943 
944 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
945                                            MlirOperation operation,
946                                            py::object parentKeepAlive) {
947   auto &liveOperations = contextRef->liveOperations;
948   assert(liveOperations.count(operation.ptr) == 0 &&
949          "cannot create detached operation that already exists");
950   (void)liveOperations;
951 
952   PyOperationRef created = createInstance(std::move(contextRef), operation,
953                                           std::move(parentKeepAlive));
954   created->attached = false;
955   return created;
956 }
957 
958 void PyOperation::checkValid() const {
959   if (!valid) {
960     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
961   }
962 }
963 
964 void PyOperationBase::print(py::object fileObject, bool binary,
965                             llvm::Optional<int64_t> largeElementsLimit,
966                             bool enableDebugInfo, bool prettyDebugInfo,
967                             bool printGenericOpForm, bool useLocalScope,
968                             bool assumeVerified) {
969   PyOperation &operation = getOperation();
970   operation.checkValid();
971   if (fileObject.is_none())
972     fileObject = py::module::import("sys").attr("stdout");
973 
974   if (!assumeVerified && !printGenericOpForm &&
975       !mlirOperationVerify(operation)) {
976     std::string message("// Verification failed, printing generic form\n");
977     if (binary) {
978       fileObject.attr("write")(py::bytes(message));
979     } else {
980       fileObject.attr("write")(py::str(message));
981     }
982     printGenericOpForm = true;
983   }
984 
985   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
986   if (largeElementsLimit)
987     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
988   if (enableDebugInfo)
989     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
990   if (printGenericOpForm)
991     mlirOpPrintingFlagsPrintGenericOpForm(flags);
992 
993   PyFileAccumulator accum(fileObject, binary);
994   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
995                               accum.getUserData());
996   mlirOpPrintingFlagsDestroy(flags);
997 }
998 
999 py::object PyOperationBase::getAsm(bool binary,
1000                                    llvm::Optional<int64_t> largeElementsLimit,
1001                                    bool enableDebugInfo, bool prettyDebugInfo,
1002                                    bool printGenericOpForm, bool useLocalScope,
1003                                    bool assumeVerified) {
1004   py::object fileObject;
1005   if (binary) {
1006     fileObject = py::module::import("io").attr("BytesIO")();
1007   } else {
1008     fileObject = py::module::import("io").attr("StringIO")();
1009   }
1010   print(fileObject, /*binary=*/binary,
1011         /*largeElementsLimit=*/largeElementsLimit,
1012         /*enableDebugInfo=*/enableDebugInfo,
1013         /*prettyDebugInfo=*/prettyDebugInfo,
1014         /*printGenericOpForm=*/printGenericOpForm,
1015         /*useLocalScope=*/useLocalScope,
1016         /*assumeVerified=*/assumeVerified);
1017 
1018   return fileObject.attr("getvalue")();
1019 }
1020 
1021 void PyOperationBase::moveAfter(PyOperationBase &other) {
1022   PyOperation &operation = getOperation();
1023   PyOperation &otherOp = other.getOperation();
1024   operation.checkValid();
1025   otherOp.checkValid();
1026   mlirOperationMoveAfter(operation, otherOp);
1027   operation.parentKeepAlive = otherOp.parentKeepAlive;
1028 }
1029 
1030 void PyOperationBase::moveBefore(PyOperationBase &other) {
1031   PyOperation &operation = getOperation();
1032   PyOperation &otherOp = other.getOperation();
1033   operation.checkValid();
1034   otherOp.checkValid();
1035   mlirOperationMoveBefore(operation, otherOp);
1036   operation.parentKeepAlive = otherOp.parentKeepAlive;
1037 }
1038 
1039 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
1040   checkValid();
1041   if (!isAttached())
1042     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
1043   MlirOperation operation = mlirOperationGetParentOperation(get());
1044   if (mlirOperationIsNull(operation))
1045     return {};
1046   return PyOperation::forOperation(getContext(), operation);
1047 }
1048 
1049 PyBlock PyOperation::getBlock() {
1050   checkValid();
1051   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
1052   MlirBlock block = mlirOperationGetBlock(get());
1053   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1054   assert(parentOperation && "Operation has no parent");
1055   return PyBlock{std::move(*parentOperation), block};
1056 }
1057 
1058 py::object PyOperation::getCapsule() {
1059   checkValid();
1060   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1061 }
1062 
1063 py::object PyOperation::createFromCapsule(py::object capsule) {
1064   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1065   if (mlirOperationIsNull(rawOperation))
1066     throw py::error_already_set();
1067   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1068   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1069       .releaseObject();
1070 }
1071 
/// Creates an operation named `name` from already-unpacked pieces and returns
/// its OpView (see createOpView()).
///
/// Insertion: if `maybeIp` is the Python `False` object the operation is left
/// detached; if it is None the thread's default insertion point (if any) is
/// used; otherwise `maybeIp` is cast to an InsertionPoint and used.
///
/// Raises ValueError for a negative `regions` count and for None
/// operands, result types or successors. Raises py::cast_error for attribute
/// keys that are not strings or attribute values that are not Attributes.
py::object PyOperation::create(
    const std::string &name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, const py::object &maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        // Note: this handler must stay ahead of the py::cast_error handler
        // below so that it gets the first chance to match.
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Fresh empty regions are created here and handed to the state via
    // mlirOperationStateAddOwnedRegions.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation. It starts out detached; the Python wrapper
  // destroys it if it never gets inserted.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);

  // InsertPoint active?
  // `False` (identity-compared against the Python False object) opts out of
  // insertion entirely; None means "use the thread's default, if any".
  if (!maybeIp.is(py::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
      ip = PyThreadContextEntry::getDefaultInsertionPoint();
    } else {
      ip = py::cast<PyInsertionPoint *>(maybeIp);
    }
    if (ip)
      ip->insert(*created.get());
  }

  return created->createOpView();
}
1204 
1205 py::object PyOperation::createOpView() {
1206   checkValid();
1207   MlirIdentifier ident = mlirOperationGetName(get());
1208   MlirStringRef identStr = mlirIdentifierStr(ident);
1209   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1210       StringRef(identStr.data, identStr.length));
1211   if (opViewClass)
1212     return (*opViewClass)(getRef().getObject());
1213   return py::cast(PyOpView(getRef().getObject()));
1214 }
1215 
1216 void PyOperation::erase() {
1217   checkValid();
1218   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1219   // Python reference to a child operation is live. All children should also
1220   // have their `valid` bit set to false.
1221   auto &liveOperations = getContext()->liveOperations;
1222   if (liveOperations.count(operation.ptr))
1223     liveOperations.erase(operation.ptr);
1224   mlirOperationDestroy(operation);
1225   valid = false;
1226 }
1227 
1228 //------------------------------------------------------------------------------
1229 // PyOpView
1230 //------------------------------------------------------------------------------
1231 
1232 py::object PyOpView::buildGeneric(
1233     const py::object &cls, py::list resultTypeList, py::list operandList,
1234     llvm::Optional<py::dict> attributes,
1235     llvm::Optional<std::vector<PyBlock *>> successors,
1236     llvm::Optional<int> regions, DefaultingPyLocation location,
1237     const py::object &maybeIp) {
1238   PyMlirContextRef context = location->getContext();
1239   // Class level operation construction metadata.
1240   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1241   // Operand and result segment specs are either none, which does no
1242   // variadic unpacking, or a list of ints with segment sizes, where each
1243   // element is either a positive number (typically 1 for a scalar) or -1 to
1244   // indicate that it is derived from the length of the same-indexed operand
1245   // or result (implying that it is a list at that position).
1246   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1247   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1248 
1249   std::vector<uint32_t> operandSegmentLengths;
1250   std::vector<uint32_t> resultSegmentLengths;
1251 
1252   // Validate/determine region count.
1253   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1254   int opMinRegionCount = std::get<0>(opRegionSpec);
1255   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1256   if (!regions) {
1257     regions = opMinRegionCount;
1258   }
1259   if (*regions < opMinRegionCount) {
1260     throw py::value_error(
1261         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1262          llvm::Twine(opMinRegionCount) +
1263          " regions but was built with regions=" + llvm::Twine(*regions))
1264             .str());
1265   }
1266   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1267     throw py::value_error(
1268         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1269          llvm::Twine(opMinRegionCount) +
1270          " regions but was built with regions=" + llvm::Twine(*regions))
1271             .str());
1272   }
1273 
1274   // Unpack results.
1275   std::vector<PyType *> resultTypes;
1276   resultTypes.reserve(resultTypeList.size());
1277   if (resultSegmentSpecObj.is_none()) {
1278     // Non-variadic result unpacking.
1279     for (const auto &it : llvm::enumerate(resultTypeList)) {
1280       try {
1281         resultTypes.push_back(py::cast<PyType *>(it.value()));
1282         if (!resultTypes.back())
1283           throw py::cast_error();
1284       } catch (py::cast_error &err) {
1285         throw py::value_error((llvm::Twine("Result ") +
1286                                llvm::Twine(it.index()) + " of operation \"" +
1287                                name + "\" must be a Type (" + err.what() + ")")
1288                                   .str());
1289       }
1290     }
1291   } else {
1292     // Sized result unpacking.
1293     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1294     if (resultSegmentSpec.size() != resultTypeList.size()) {
1295       throw py::value_error((llvm::Twine("Operation \"") + name +
1296                              "\" requires " +
1297                              llvm::Twine(resultSegmentSpec.size()) +
1298                              " result segments but was provided " +
1299                              llvm::Twine(resultTypeList.size()))
1300                                 .str());
1301     }
1302     resultSegmentLengths.reserve(resultTypeList.size());
1303     for (const auto &it :
1304          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1305       int segmentSpec = std::get<1>(it.value());
1306       if (segmentSpec == 1 || segmentSpec == 0) {
1307         // Unpack unary element.
1308         try {
1309           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1310           if (resultType) {
1311             resultTypes.push_back(resultType);
1312             resultSegmentLengths.push_back(1);
1313           } else if (segmentSpec == 0) {
1314             // Allowed to be optional.
1315             resultSegmentLengths.push_back(0);
1316           } else {
1317             throw py::cast_error("was None and result is not optional");
1318           }
1319         } catch (py::cast_error &err) {
1320           throw py::value_error((llvm::Twine("Result ") +
1321                                  llvm::Twine(it.index()) + " of operation \"" +
1322                                  name + "\" must be a Type (" + err.what() +
1323                                  ")")
1324                                     .str());
1325         }
1326       } else if (segmentSpec == -1) {
1327         // Unpack sequence by appending.
1328         try {
1329           if (std::get<0>(it.value()).is_none()) {
1330             // Treat it as an empty list.
1331             resultSegmentLengths.push_back(0);
1332           } else {
1333             // Unpack the list.
1334             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1335             for (py::object segmentItem : segment) {
1336               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1337               if (!resultTypes.back()) {
1338                 throw py::cast_error("contained a None item");
1339               }
1340             }
1341             resultSegmentLengths.push_back(segment.size());
1342           }
1343         } catch (std::exception &err) {
1344           // NOTE: Sloppy to be using a catch-all here, but there are at least
1345           // three different unrelated exceptions that can be thrown in the
1346           // above "casts". Just keep the scope above small and catch them all.
1347           throw py::value_error((llvm::Twine("Result ") +
1348                                  llvm::Twine(it.index()) + " of operation \"" +
1349                                  name + "\" must be a Sequence of Types (" +
1350                                  err.what() + ")")
1351                                     .str());
1352         }
1353       } else {
1354         throw py::value_error("Unexpected segment spec");
1355       }
1356     }
1357   }
1358 
1359   // Unpack operands.
1360   std::vector<PyValue *> operands;
1361   operands.reserve(operands.size());
1362   if (operandSegmentSpecObj.is_none()) {
1363     // Non-sized operand unpacking.
1364     for (const auto &it : llvm::enumerate(operandList)) {
1365       try {
1366         operands.push_back(py::cast<PyValue *>(it.value()));
1367         if (!operands.back())
1368           throw py::cast_error();
1369       } catch (py::cast_error &err) {
1370         throw py::value_error((llvm::Twine("Operand ") +
1371                                llvm::Twine(it.index()) + " of operation \"" +
1372                                name + "\" must be a Value (" + err.what() + ")")
1373                                   .str());
1374       }
1375     }
1376   } else {
1377     // Sized operand unpacking.
1378     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1379     if (operandSegmentSpec.size() != operandList.size()) {
1380       throw py::value_error((llvm::Twine("Operation \"") + name +
1381                              "\" requires " +
1382                              llvm::Twine(operandSegmentSpec.size()) +
1383                              "operand segments but was provided " +
1384                              llvm::Twine(operandList.size()))
1385                                 .str());
1386     }
1387     operandSegmentLengths.reserve(operandList.size());
1388     for (const auto &it :
1389          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1390       int segmentSpec = std::get<1>(it.value());
1391       if (segmentSpec == 1 || segmentSpec == 0) {
1392         // Unpack unary element.
1393         try {
1394           auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1395           if (operandValue) {
1396             operands.push_back(operandValue);
1397             operandSegmentLengths.push_back(1);
1398           } else if (segmentSpec == 0) {
1399             // Allowed to be optional.
1400             operandSegmentLengths.push_back(0);
1401           } else {
1402             throw py::cast_error("was None and operand is not optional");
1403           }
1404         } catch (py::cast_error &err) {
1405           throw py::value_error((llvm::Twine("Operand ") +
1406                                  llvm::Twine(it.index()) + " of operation \"" +
1407                                  name + "\" must be a Value (" + err.what() +
1408                                  ")")
1409                                     .str());
1410         }
1411       } else if (segmentSpec == -1) {
1412         // Unpack sequence by appending.
1413         try {
1414           if (std::get<0>(it.value()).is_none()) {
1415             // Treat it as an empty list.
1416             operandSegmentLengths.push_back(0);
1417           } else {
1418             // Unpack the list.
1419             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1420             for (py::object segmentItem : segment) {
1421               operands.push_back(py::cast<PyValue *>(segmentItem));
1422               if (!operands.back()) {
1423                 throw py::cast_error("contained a None item");
1424               }
1425             }
1426             operandSegmentLengths.push_back(segment.size());
1427           }
1428         } catch (std::exception &err) {
1429           // NOTE: Sloppy to be using a catch-all here, but there are at least
1430           // three different unrelated exceptions that can be thrown in the
1431           // above "casts". Just keep the scope above small and catch them all.
1432           throw py::value_error((llvm::Twine("Operand ") +
1433                                  llvm::Twine(it.index()) + " of operation \"" +
1434                                  name + "\" must be a Sequence of Values (" +
1435                                  err.what() + ")")
1436                                     .str());
1437         }
1438       } else {
1439         throw py::value_error("Unexpected segment spec");
1440       }
1441     }
1442   }
1443 
1444   // Merge operand/result segment lengths into attributes if needed.
1445   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1446     // Dup.
1447     if (attributes) {
1448       attributes = py::dict(*attributes);
1449     } else {
1450       attributes = py::dict();
1451     }
1452     if (attributes->contains("result_segment_sizes") ||
1453         attributes->contains("operand_segment_sizes")) {
1454       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1455                             "'operand_segment_sizes' attribute is unsupported. "
1456                             "Use Operation.create for such low-level access.");
1457     }
1458 
1459     // Add result_segment_sizes attribute.
1460     if (!resultSegmentLengths.empty()) {
1461       int64_t size = resultSegmentLengths.size();
1462       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1463           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1464           resultSegmentLengths.size(), resultSegmentLengths.data());
1465       (*attributes)["result_segment_sizes"] =
1466           PyAttribute(context, segmentLengthAttr);
1467     }
1468 
1469     // Add operand_segment_sizes attribute.
1470     if (!operandSegmentLengths.empty()) {
1471       int64_t size = operandSegmentLengths.size();
1472       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1473           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1474           operandSegmentLengths.size(), operandSegmentLengths.data());
1475       (*attributes)["operand_segment_sizes"] =
1476           PyAttribute(context, segmentLengthAttr);
1477     }
1478   }
1479 
1480   // Delegate to create.
1481   return PyOperation::create(name,
1482                              /*results=*/std::move(resultTypes),
1483                              /*operands=*/std::move(operands),
1484                              /*attributes=*/std::move(attributes),
1485                              /*successors=*/std::move(successors),
1486                              /*regions=*/*regions, location, maybeIp);
1487 }
1488 
/// Wraps the operation behind `operationObject`, which may be any
/// PyOperationBase (an Operation or another OpView). The stored
/// `operationObject` member is the canonical Python object of the underlying
/// PyOperation, not necessarily the object that was passed in.
///
/// NOTE: the `operationObject` member is initialized from the `operation`
/// member, so this relies on `operation` being declared before
/// `operationObject` in the class definition.
PyOpView::PyOpView(const py::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1494 
1495 py::object PyOpView::createRawSubclass(const py::object &userClass) {
1496   // This is... a little gross. The typical pattern is to have a pure python
1497   // class that extends OpView like:
1498   //   class AddFOp(_cext.ir.OpView):
1499   //     def __init__(self, loc, lhs, rhs):
1500   //       operation = loc.context.create_operation(
1501   //           "addf", lhs, rhs, results=[lhs.type])
1502   //       super().__init__(operation)
1503   //
1504   // I.e. The goal of the user facing type is to provide a nice constructor
1505   // that has complete freedom for the op under construction. This is at odds
1506   // with our other desire to sometimes create this object by just passing an
1507   // operation (to initialize the base class). We could do *arg and **kwargs
1508   // munging to try to make it work, but instead, we synthesize a new class
1509   // on the fly which extends this user class (AddFOp in this example) and
1510   // *give it* the base class's __init__ method, thus bypassing the
1511   // intermediate subclass's __init__ method entirely. While slightly,
1512   // underhanded, this is safe/legal because the type hierarchy has not changed
1513   // (we just added a new leaf) and we aren't mucking around with __new__.
1514   // Typically, this new class will be stored on the original as "_Raw" and will
1515   // be used for casts and other things that need a variant of the class that
1516   // is initialized purely from an operation.
1517   py::object parentMetaclass =
1518       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1519   py::dict attributes;
1520   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1521   // now.
1522   //   auto opViewType = py::type::of<PyOpView>();
1523   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1524   attributes["__init__"] = opViewType.attr("__init__");
1525   py::str origName = userClass.attr("__name__");
1526   py::str newName = py::str("_") + origName;
1527   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1528 }
1529 
1530 //------------------------------------------------------------------------------
1531 // PyInsertionPoint.
1532 //------------------------------------------------------------------------------
1533 
1534 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1535 
/// Creates an insertion point immediately before `beforeOperationBase`, inside
/// the block that contains it.
/// NOTE: `block` is initialized by dereferencing `refOperation`, so this relies
/// on `refOperation` being declared before `block` in the class definition
/// (members are initialized in declaration order).
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1539 
1540 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1541   PyOperation &operation = operationBase.getOperation();
1542   if (operation.isAttached())
1543     throw SetPyError(PyExc_ValueError,
1544                      "Attempt to insert operation that is already attached");
1545   block.getParentOperation()->checkValid();
1546   MlirOperation beforeOp = {nullptr};
1547   if (refOperation) {
1548     // Insert before operation.
1549     (*refOperation)->checkValid();
1550     beforeOp = (*refOperation)->get();
1551   } else {
1552     // Insert at end (before null) is only valid if the block does not
1553     // already end in a known terminator (violating this will cause assertion
1554     // failures later).
1555     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1556       throw py::index_error("Cannot insert operation at the end of a block "
1557                             "that already has a terminator. Did you mean to "
1558                             "use 'InsertionPoint.at_block_terminator(block)' "
1559                             "versus 'InsertionPoint(block)'?");
1560     }
1561   }
1562   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1563   operation.setAttached();
1564 }
1565 
1566 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1567   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1568   if (mlirOperationIsNull(firstOp)) {
1569     // Just insert at end.
1570     return PyInsertionPoint(block);
1571   }
1572 
1573   // Insert before first op.
1574   PyOperationRef firstOpRef = PyOperation::forOperation(
1575       block.getParentOperation()->getContext(), firstOp);
1576   return PyInsertionPoint{block, std::move(firstOpRef)};
1577 }
1578 
1579 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1580   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1581   if (mlirOperationIsNull(terminator))
1582     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1583   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1584       block.getParentOperation()->getContext(), terminator);
1585   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1586 }
1587 
1588 py::object PyInsertionPoint::contextEnter() {
1589   return PyThreadContextEntry::pushInsertionPoint(*this);
1590 }
1591 
1592 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1593                                    const pybind11::object &excVal,
1594                                    const pybind11::object &excTb) {
1595   PyThreadContextEntry::popInsertionPoint(*this);
1596 }
1597 
1598 //------------------------------------------------------------------------------
1599 // PyAttribute.
1600 //------------------------------------------------------------------------------
1601 
1602 bool PyAttribute::operator==(const PyAttribute &other) {
1603   return mlirAttributeEqual(attr, other.attr);
1604 }
1605 
1606 py::object PyAttribute::getCapsule() {
1607   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1608 }
1609 
1610 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1611   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1612   if (mlirAttributeIsNull(rawAttr))
1613     throw py::error_already_set();
1614   return PyAttribute(
1615       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1616 }
1617 
1618 //------------------------------------------------------------------------------
1619 // PyNamedAttribute.
1620 //------------------------------------------------------------------------------
1621 
1622 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1623     : ownedName(new std::string(std::move(ownedName))) {
1624   namedAttr = mlirNamedAttributeGet(
1625       mlirIdentifierGet(mlirAttributeGetContext(attr),
1626                         toMlirStringRef(*this->ownedName)),
1627       attr);
1628 }
1629 
1630 //------------------------------------------------------------------------------
1631 // PyType.
1632 //------------------------------------------------------------------------------
1633 
1634 bool PyType::operator==(const PyType &other) {
1635   return mlirTypeEqual(type, other.type);
1636 }
1637 
1638 py::object PyType::getCapsule() {
1639   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1640 }
1641 
1642 PyType PyType::createFromCapsule(py::object capsule) {
1643   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1644   if (mlirTypeIsNull(rawType))
1645     throw py::error_already_set();
1646   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1647                 rawType);
1648 }
1649 
1650 //------------------------------------------------------------------------------
// PyValue and subclasses.
1652 //------------------------------------------------------------------------------
1653 
1654 pybind11::object PyValue::getCapsule() {
1655   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1656 }
1657 
1658 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1659   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1660   if (mlirValueIsNull(value))
1661     throw py::error_already_set();
1662   MlirOperation owner;
1663   if (mlirValueIsAOpResult(value))
1664     owner = mlirOpResultGetOwner(value);
1665   if (mlirValueIsABlockArgument(value))
1666     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1667   if (mlirOperationIsNull(owner))
1668     throw py::error_already_set();
1669   MlirContext ctx = mlirOperationGetContext(owner);
1670   PyOperationRef ownerRef =
1671       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1672   return PyValue(ownerRef, value);
1673 }
1674 
1675 //------------------------------------------------------------------------------
1676 // PySymbolTable.
1677 //------------------------------------------------------------------------------
1678 
1679 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1680     : operation(operation.getOperation().getRef()) {
1681   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1682   if (mlirSymbolTableIsNull(symbolTable)) {
1683     throw py::cast_error("Operation is not a Symbol Table.");
1684   }
1685 }
1686 
1687 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1688   operation->checkValid();
1689   MlirOperation symbol = mlirSymbolTableLookup(
1690       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1691   if (mlirOperationIsNull(symbol))
1692     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1693 
1694   return PyOperation::forOperation(operation->getContext(), symbol,
1695                                    operation.getObject())
1696       ->createOpView();
1697 }
1698 
1699 void PySymbolTable::erase(PyOperationBase &symbol) {
1700   operation->checkValid();
1701   symbol.getOperation().checkValid();
1702   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1703   // The operation is also erased, so we must invalidate it. There may be Python
1704   // references to this operation so we don't want to delete it from the list of
1705   // live operations here.
1706   symbol.getOperation().valid = false;
1707 }
1708 
1709 void PySymbolTable::dunderDel(const std::string &name) {
1710   py::object operation = dunderGetItem(name);
1711   erase(py::cast<PyOperationBase &>(operation));
1712 }
1713 
1714 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1715   operation->checkValid();
1716   symbol.getOperation().checkValid();
1717   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1718       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1719   if (mlirAttributeIsNull(symbolAttr))
1720     throw py::value_error("Expected operation to have a symbol name.");
1721   return PyAttribute(
1722       symbol.getOperation().getContext(),
1723       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1724 }
1725 
1726 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1727   // Op must already be a symbol.
1728   PyOperation &operation = symbol.getOperation();
1729   operation.checkValid();
1730   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1731   MlirAttribute existingNameAttr =
1732       mlirOperationGetAttributeByName(operation.get(), attrName);
1733   if (mlirAttributeIsNull(existingNameAttr))
1734     throw py::value_error("Expected operation to have a symbol name.");
1735   return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1736 }
1737 
1738 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1739                                   const std::string &name) {
1740   // Op must already be a symbol.
1741   PyOperation &operation = symbol.getOperation();
1742   operation.checkValid();
1743   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1744   MlirAttribute existingNameAttr =
1745       mlirOperationGetAttributeByName(operation.get(), attrName);
1746   if (mlirAttributeIsNull(existingNameAttr))
1747     throw py::value_error("Expected operation to have a symbol name.");
1748   MlirAttribute newNameAttr =
1749       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1750   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1751 }
1752 
1753 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1754   PyOperation &operation = symbol.getOperation();
1755   operation.checkValid();
1756   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1757   MlirAttribute existingVisAttr =
1758       mlirOperationGetAttributeByName(operation.get(), attrName);
1759   if (mlirAttributeIsNull(existingVisAttr))
1760     throw py::value_error("Expected operation to have a symbol visibility.");
1761   return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1762 }
1763 
1764 void PySymbolTable::setVisibility(PyOperationBase &symbol,
1765                                   const std::string &visibility) {
1766   if (visibility != "public" && visibility != "private" &&
1767       visibility != "nested")
1768     throw py::value_error(
1769         "Expected visibility to be 'public', 'private' or 'nested'");
1770   PyOperation &operation = symbol.getOperation();
1771   operation.checkValid();
1772   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1773   MlirAttribute existingVisAttr =
1774       mlirOperationGetAttributeByName(operation.get(), attrName);
1775   if (mlirAttributeIsNull(existingVisAttr))
1776     throw py::value_error("Expected operation to have a symbol visibility.");
1777   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1778                                                toMlirStringRef(visibility));
1779   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1780 }
1781 
1782 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1783                                          const std::string &newSymbol,
1784                                          PyOperationBase &from) {
1785   PyOperation &fromOperation = from.getOperation();
1786   fromOperation.checkValid();
1787   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1788           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1789           from.getOperation())))
1790 
1791     throw py::value_error("Symbol rename failed");
1792 }
1793 
1794 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1795                                      bool allSymUsesVisible,
1796                                      py::object callback) {
1797   PyOperation &fromOperation = from.getOperation();
1798   fromOperation.checkValid();
1799   struct UserData {
1800     PyMlirContextRef context;
1801     py::object callback;
1802     bool gotException;
1803     std::string exceptionWhat;
1804     py::object exceptionType;
1805   };
1806   UserData userData{
1807       fromOperation.getContext(), std::move(callback), false, {}, {}};
1808   mlirSymbolTableWalkSymbolTables(
1809       fromOperation.get(), allSymUsesVisible,
1810       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1811         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1812         auto pyFoundOp =
1813             PyOperation::forOperation(calleeUserData->context, foundOp);
1814         if (calleeUserData->gotException)
1815           return;
1816         try {
1817           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1818         } catch (py::error_already_set &e) {
1819           calleeUserData->gotException = true;
1820           calleeUserData->exceptionWhat = e.what();
1821           calleeUserData->exceptionType = e.type();
1822         }
1823       },
1824       static_cast<void *>(&userData));
1825   if (userData.gotException) {
1826     std::string message("Exception raised in callback: ");
1827     message.append(userData.exceptionWhat);
1828     throw std::runtime_error(message);
1829   }
1830 }
1831 
1832 namespace {
1833 /// CRTP base class for Python MLIR values that subclass Value and should be
1834 /// castable from it. The value hierarchy is one level deep and is not supposed
1835 /// to accommodate other levels unless core MLIR changes.
1836 template <typename DerivedTy>
1837 class PyConcreteValue : public PyValue {
1838 public:
1839   // Derived classes must define statics for:
1840   //   IsAFunctionTy isaFunction
1841   //   const char *pyClassName
1842   // and redefine bindDerived.
1843   using ClassTy = py::class_<DerivedTy, PyValue>;
1844   using IsAFunctionTy = bool (*)(MlirValue);
1845 
1846   PyConcreteValue() = default;
1847   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1848       : PyValue(operationRef, value) {}
1849   PyConcreteValue(PyValue &orig)
1850       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1851 
1852   /// Attempts to cast the original value to the derived type and throws on
1853   /// type mismatches.
1854   static MlirValue castFrom(PyValue &orig) {
1855     if (!DerivedTy::isaFunction(orig.get())) {
1856       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1857       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1858                                              DerivedTy::pyClassName +
1859                                              " (from " + origRepr + ")");
1860     }
1861     return orig.get();
1862   }
1863 
1864   /// Binds the Python module objects to functions of this class.
1865   static void bind(py::module &m) {
1866     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1867     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1868     cls.def_static(
1869         "isinstance",
1870         [](PyValue &otherValue) -> bool {
1871           return DerivedTy::isaFunction(otherValue);
1872         },
1873         py::arg("other_value"));
1874     DerivedTy::bindDerived(cls);
1875   }
1876 
1877   /// Implemented by derived classes to add methods to the Python subclass.
1878   static void bindDerived(ClassTy &m) {}
1879 };
1880 
1881 /// Python wrapper for MlirBlockArgument.
1882 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1883 public:
1884   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1885   static constexpr const char *pyClassName = "BlockArgument";
1886   using PyConcreteValue::PyConcreteValue;
1887 
1888   static void bindDerived(ClassTy &c) {
1889     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1890       return PyBlock(self.getParentOperation(),
1891                      mlirBlockArgumentGetOwner(self.get()));
1892     });
1893     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1894       return mlirBlockArgumentGetArgNumber(self.get());
1895     });
1896     c.def(
1897         "set_type",
1898         [](PyBlockArgument &self, PyType type) {
1899           return mlirBlockArgumentSetType(self.get(), type);
1900         },
1901         py::arg("type"));
1902   }
1903 };
1904 
1905 /// Python wrapper for MlirOpResult.
1906 class PyOpResult : public PyConcreteValue<PyOpResult> {
1907 public:
1908   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1909   static constexpr const char *pyClassName = "OpResult";
1910   using PyConcreteValue::PyConcreteValue;
1911 
1912   static void bindDerived(ClassTy &c) {
1913     c.def_property_readonly("owner", [](PyOpResult &self) {
1914       assert(
1915           mlirOperationEqual(self.getParentOperation()->get(),
1916                              mlirOpResultGetOwner(self.get())) &&
1917           "expected the owner of the value in Python to match that in the IR");
1918       return self.getParentOperation().getObject();
1919     });
1920     c.def_property_readonly("result_number", [](PyOpResult &self) {
1921       return mlirOpResultGetResultNumber(self.get());
1922     });
1923   }
1924 };
1925 
1926 /// Returns the list of types of the values held by container.
1927 template <typename Container>
1928 static std::vector<PyType> getValueTypes(Container &container,
1929                                          PyMlirContextRef &context) {
1930   std::vector<PyType> result;
1931   result.reserve(container.getNumElements());
1932   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1933     result.push_back(
1934         PyType(context, mlirValueGetType(container.getElement(i).get())));
1935   }
1936   return result;
1937 }
1938 
1939 /// A list of block arguments. Internally, these are stored as consecutive
1940 /// elements, random access is cheap. The argument list is associated with the
1941 /// operation that contains the block (detached blocks are not allowed in
1942 /// Python bindings) and extends its lifetime.
1943 class PyBlockArgumentList
1944     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1945 public:
1946   static constexpr const char *pyClassName = "BlockArgumentList";
1947 
1948   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1949                       intptr_t startIndex = 0, intptr_t length = -1,
1950                       intptr_t step = 1)
1951       : Sliceable(startIndex,
1952                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1953                   step),
1954         operation(std::move(operation)), block(block) {}
1955 
1956   /// Returns the number of arguments in the list.
1957   intptr_t getNumElements() {
1958     operation->checkValid();
1959     return mlirBlockGetNumArguments(block);
1960   }
1961 
1962   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1963   PyBlockArgument getElement(intptr_t pos) {
1964     MlirValue argument = mlirBlockGetArgument(block, pos);
1965     return PyBlockArgument(operation, argument);
1966   }
1967 
1968   /// Returns a sublist of this list.
1969   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1970                             intptr_t step) {
1971     return PyBlockArgumentList(operation, block, startIndex, length, step);
1972   }
1973 
1974   static void bindDerived(ClassTy &c) {
1975     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1976       return getValueTypes(self, self.operation->getContext());
1977     });
1978   }
1979 
1980 private:
1981   PyOperationRef operation;
1982   MlirBlock block;
1983 };
1984 
1985 /// A list of operation operands. Internally, these are stored as consecutive
1986 /// elements, random access is cheap. The result list is associated with the
1987 /// operation whose results these are, and extends the lifetime of this
1988 /// operation.
1989 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1990 public:
1991   static constexpr const char *pyClassName = "OpOperandList";
1992 
1993   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1994                   intptr_t length = -1, intptr_t step = 1)
1995       : Sliceable(startIndex,
1996                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1997                                : length,
1998                   step),
1999         operation(operation) {}
2000 
2001   intptr_t getNumElements() {
2002     operation->checkValid();
2003     return mlirOperationGetNumOperands(operation->get());
2004   }
2005 
2006   PyValue getElement(intptr_t pos) {
2007     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2008     MlirOperation owner;
2009     if (mlirValueIsAOpResult(operand))
2010       owner = mlirOpResultGetOwner(operand);
2011     else if (mlirValueIsABlockArgument(operand))
2012       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2013     else
2014       assert(false && "Value must be an block arg or op result.");
2015     PyOperationRef pyOwner =
2016         PyOperation::forOperation(operation->getContext(), owner);
2017     return PyValue(pyOwner, operand);
2018   }
2019 
2020   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2021     return PyOpOperandList(operation, startIndex, length, step);
2022   }
2023 
2024   void dunderSetItem(intptr_t index, PyValue value) {
2025     index = wrapIndex(index);
2026     mlirOperationSetOperand(operation->get(), index, value.get());
2027   }
2028 
2029   static void bindDerived(ClassTy &c) {
2030     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2031   }
2032 
2033 private:
2034   PyOperationRef operation;
2035 };
2036 
2037 /// A list of operation results. Internally, these are stored as consecutive
2038 /// elements, random access is cheap. The result list is associated with the
2039 /// operation whose results these are, and extends the lifetime of this
2040 /// operation.
2041 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2042 public:
2043   static constexpr const char *pyClassName = "OpResultList";
2044 
2045   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2046                  intptr_t length = -1, intptr_t step = 1)
2047       : Sliceable(startIndex,
2048                   length == -1 ? mlirOperationGetNumResults(operation->get())
2049                                : length,
2050                   step),
2051         operation(operation) {}
2052 
2053   intptr_t getNumElements() {
2054     operation->checkValid();
2055     return mlirOperationGetNumResults(operation->get());
2056   }
2057 
2058   PyOpResult getElement(intptr_t index) {
2059     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2060     return PyOpResult(value);
2061   }
2062 
2063   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2064     return PyOpResultList(operation, startIndex, length, step);
2065   }
2066 
2067   static void bindDerived(ClassTy &c) {
2068     c.def_property_readonly("types", [](PyOpResultList &self) {
2069       return getValueTypes(self, self.operation->getContext());
2070     });
2071   }
2072 
2073 private:
2074   PyOperationRef operation;
2075 };
2076 
2077 /// A list of operation attributes. Can be indexed by name, producing
2078 /// attributes, or by index, producing named attributes.
2079 class PyOpAttributeMap {
2080 public:
2081   PyOpAttributeMap(PyOperationRef operation)
2082       : operation(std::move(operation)) {}
2083 
2084   PyAttribute dunderGetItemNamed(const std::string &name) {
2085     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2086                                                          toMlirStringRef(name));
2087     if (mlirAttributeIsNull(attr)) {
2088       throw SetPyError(PyExc_KeyError,
2089                        "attempt to access a non-existent attribute");
2090     }
2091     return PyAttribute(operation->getContext(), attr);
2092   }
2093 
2094   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2095     if (index < 0 || index >= dunderLen()) {
2096       throw SetPyError(PyExc_IndexError,
2097                        "attempt to access out of bounds attribute");
2098     }
2099     MlirNamedAttribute namedAttr =
2100         mlirOperationGetAttribute(operation->get(), index);
2101     return PyNamedAttribute(
2102         namedAttr.attribute,
2103         std::string(mlirIdentifierStr(namedAttr.name).data,
2104                     mlirIdentifierStr(namedAttr.name).length));
2105   }
2106 
2107   void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2108     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2109                                     attr);
2110   }
2111 
2112   void dunderDelItem(const std::string &name) {
2113     int removed = mlirOperationRemoveAttributeByName(operation->get(),
2114                                                      toMlirStringRef(name));
2115     if (!removed)
2116       throw SetPyError(PyExc_KeyError,
2117                        "attempt to delete a non-existent attribute");
2118   }
2119 
2120   intptr_t dunderLen() {
2121     return mlirOperationGetNumAttributes(operation->get());
2122   }
2123 
2124   bool dunderContains(const std::string &name) {
2125     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2126         operation->get(), toMlirStringRef(name)));
2127   }
2128 
2129   static void bind(py::module &m) {
2130     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2131         .def("__contains__", &PyOpAttributeMap::dunderContains)
2132         .def("__len__", &PyOpAttributeMap::dunderLen)
2133         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2134         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2135         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2136         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2137   }
2138 
2139 private:
2140   PyOperationRef operation;
2141 };
2142 
2143 } // namespace
2144 
2145 //------------------------------------------------------------------------------
2146 // Populates the core exports of the 'ir' submodule.
2147 //------------------------------------------------------------------------------
2148 
2149 void mlir::python::populateIRCore(py::module &m) {
2150   //----------------------------------------------------------------------------
2151   // Enums.
2152   //----------------------------------------------------------------------------
2153   py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2154       .value("ERROR", MlirDiagnosticError)
2155       .value("WARNING", MlirDiagnosticWarning)
2156       .value("NOTE", MlirDiagnosticNote)
2157       .value("REMARK", MlirDiagnosticRemark);
2158 
2159   //----------------------------------------------------------------------------
2160   // Mapping of Diagnostics.
2161   //----------------------------------------------------------------------------
2162   py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2163       .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2164       .def_property_readonly("location", &PyDiagnostic::getLocation)
2165       .def_property_readonly("message", &PyDiagnostic::getMessage)
2166       .def_property_readonly("notes", &PyDiagnostic::getNotes)
2167       .def("__str__", [](PyDiagnostic &self) -> py::str {
2168         if (!self.isValid())
2169           return "<Invalid Diagnostic>";
2170         return self.getMessage();
2171       });
2172 
2173   py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2174       .def("detach", &PyDiagnosticHandler::detach)
2175       .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2176       .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2177       .def("__enter__", &PyDiagnosticHandler::contextEnter)
2178       .def("__exit__", &PyDiagnosticHandler::contextExit);
2179 
2180   //----------------------------------------------------------------------------
2181   // Mapping of MlirContext.
2182   //----------------------------------------------------------------------------
2183   py::class_<PyMlirContext>(m, "Context", py::module_local())
2184       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2185       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2186       .def("_get_context_again",
2187            [](PyMlirContext &self) {
2188              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2189              return ref.releaseObject();
2190            })
2191       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2192       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2193       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2194                              &PyMlirContext::getCapsule)
2195       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2196       .def("__enter__", &PyMlirContext::contextEnter)
2197       .def("__exit__", &PyMlirContext::contextExit)
2198       .def_property_readonly_static(
2199           "current",
2200           [](py::object & /*class*/) {
2201             auto *context = PyThreadContextEntry::getDefaultContext();
2202             if (!context)
2203               throw SetPyError(PyExc_ValueError, "No current Context");
2204             return context;
2205           },
2206           "Gets the Context bound to the current thread or raises ValueError")
2207       .def_property_readonly(
2208           "dialects",
2209           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2210           "Gets a container for accessing dialects by name")
2211       .def_property_readonly(
2212           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2213           "Alias for 'dialect'")
2214       .def(
2215           "get_dialect_descriptor",
2216           [=](PyMlirContext &self, std::string &name) {
2217             MlirDialect dialect = mlirContextGetOrLoadDialect(
2218                 self.get(), {name.data(), name.size()});
2219             if (mlirDialectIsNull(dialect)) {
2220               throw SetPyError(PyExc_ValueError,
2221                                Twine("Dialect '") + name + "' not found");
2222             }
2223             return PyDialectDescriptor(self.getRef(), dialect);
2224           },
2225           py::arg("dialect_name"),
2226           "Gets or loads a dialect by name, returning its descriptor object")
2227       .def_property(
2228           "allow_unregistered_dialects",
2229           [](PyMlirContext &self) -> bool {
2230             return mlirContextGetAllowUnregisteredDialects(self.get());
2231           },
2232           [](PyMlirContext &self, bool value) {
2233             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2234           })
2235       .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2236            py::arg("callback"),
2237            "Attaches a diagnostic handler that will receive callbacks")
2238       .def(
2239           "enable_multithreading",
2240           [](PyMlirContext &self, bool enable) {
2241             mlirContextEnableMultithreading(self.get(), enable);
2242           },
2243           py::arg("enable"))
2244       .def(
2245           "is_registered_operation",
2246           [](PyMlirContext &self, std::string &name) {
2247             return mlirContextIsRegisteredOperation(
2248                 self.get(), MlirStringRef{name.data(), name.size()});
2249           },
2250           py::arg("operation_name"));
2251 
2252   //----------------------------------------------------------------------------
2253   // Mapping of PyDialectDescriptor
2254   //----------------------------------------------------------------------------
2255   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2256       .def_property_readonly("namespace",
2257                              [](PyDialectDescriptor &self) {
2258                                MlirStringRef ns =
2259                                    mlirDialectGetNamespace(self.get());
2260                                return py::str(ns.data, ns.length);
2261                              })
2262       .def("__repr__", [](PyDialectDescriptor &self) {
2263         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2264         std::string repr("<DialectDescriptor ");
2265         repr.append(ns.data, ns.length);
2266         repr.append(">");
2267         return repr;
2268       });
2269 
2270   //----------------------------------------------------------------------------
2271   // Mapping of PyDialects
2272   //----------------------------------------------------------------------------
2273   py::class_<PyDialects>(m, "Dialects", py::module_local())
2274       .def("__getitem__",
2275            [=](PyDialects &self, std::string keyName) {
2276              MlirDialect dialect =
2277                  self.getDialectForKey(keyName, /*attrError=*/false);
2278              py::object descriptor =
2279                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2280              return createCustomDialectWrapper(keyName, std::move(descriptor));
2281            })
2282       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2283         MlirDialect dialect =
2284             self.getDialectForKey(attrName, /*attrError=*/true);
2285         py::object descriptor =
2286             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2287         return createCustomDialectWrapper(attrName, std::move(descriptor));
2288       });
2289 
2290   //----------------------------------------------------------------------------
2291   // Mapping of PyDialect
2292   //----------------------------------------------------------------------------
2293   py::class_<PyDialect>(m, "Dialect", py::module_local())
2294       .def(py::init<py::object>(), py::arg("descriptor"))
2295       .def_property_readonly(
2296           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2297       .def("__repr__", [](py::object self) {
2298         auto clazz = self.attr("__class__");
2299         return py::str("<Dialect ") +
2300                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2301                clazz.attr("__module__") + py::str(".") +
2302                clazz.attr("__name__") + py::str(")>");
2303       });
2304 
2305   //----------------------------------------------------------------------------
2306   // Mapping of Location
2307   //----------------------------------------------------------------------------
2308   py::class_<PyLocation>(m, "Location", py::module_local())
2309       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2310       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2311       .def("__enter__", &PyLocation::contextEnter)
2312       .def("__exit__", &PyLocation::contextExit)
2313       .def("__eq__",
2314            [](PyLocation &self, PyLocation &other) -> bool {
2315              return mlirLocationEqual(self, other);
2316            })
2317       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2318       .def_property_readonly_static(
2319           "current",
2320           [](py::object & /*class*/) {
2321             auto *loc = PyThreadContextEntry::getDefaultLocation();
2322             if (!loc)
2323               throw SetPyError(PyExc_ValueError, "No current Location");
2324             return loc;
2325           },
2326           "Gets the Location bound to the current thread or raises ValueError")
2327       .def_static(
2328           "unknown",
2329           [](DefaultingPyMlirContext context) {
2330             return PyLocation(context->getRef(),
2331                               mlirLocationUnknownGet(context->get()));
2332           },
2333           py::arg("context") = py::none(),
2334           "Gets a Location representing an unknown location")
2335       .def_static(
2336           "callsite",
2337           [](PyLocation callee, const std::vector<PyLocation> &frames,
2338              DefaultingPyMlirContext context) {
2339             if (frames.empty())
2340               throw py::value_error("No caller frames provided");
2341             MlirLocation caller = frames.back().get();
2342             for (const PyLocation &frame :
2343                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2344               caller = mlirLocationCallSiteGet(frame.get(), caller);
2345             return PyLocation(context->getRef(),
2346                               mlirLocationCallSiteGet(callee.get(), caller));
2347           },
2348           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2349           kContextGetCallSiteLocationDocstring)
2350       .def_static(
2351           "file",
2352           [](std::string filename, int line, int col,
2353              DefaultingPyMlirContext context) {
2354             return PyLocation(
2355                 context->getRef(),
2356                 mlirLocationFileLineColGet(
2357                     context->get(), toMlirStringRef(filename), line, col));
2358           },
2359           py::arg("filename"), py::arg("line"), py::arg("col"),
2360           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2361       .def_static(
2362           "fused",
2363           [](const std::vector<PyLocation> &pyLocations,
2364              llvm::Optional<PyAttribute> metadata,
2365              DefaultingPyMlirContext context) {
2366             llvm::SmallVector<MlirLocation, 4> locations;
2367             locations.reserve(pyLocations.size());
2368             for (auto &pyLocation : pyLocations)
2369               locations.push_back(pyLocation.get());
2370             MlirLocation location = mlirLocationFusedGet(
2371                 context->get(), locations.size(), locations.data(),
2372                 metadata ? metadata->get() : MlirAttribute{0});
2373             return PyLocation(context->getRef(), location);
2374           },
2375           py::arg("locations"), py::arg("metadata") = py::none(),
2376           py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2377       .def_static(
2378           "name",
2379           [](std::string name, llvm::Optional<PyLocation> childLoc,
2380              DefaultingPyMlirContext context) {
2381             return PyLocation(
2382                 context->getRef(),
2383                 mlirLocationNameGet(
2384                     context->get(), toMlirStringRef(name),
2385                     childLoc ? childLoc->get()
2386                              : mlirLocationUnknownGet(context->get())));
2387           },
2388           py::arg("name"), py::arg("childLoc") = py::none(),
2389           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2390       .def_property_readonly(
2391           "context",
2392           [](PyLocation &self) { return self.getContext().getObject(); },
2393           "Context that owns the Location")
2394       .def(
2395           "emit_error",
2396           [](PyLocation &self, std::string message) {
2397             mlirEmitError(self, message.c_str());
2398           },
2399           py::arg("message"), "Emits an error at this location")
2400       .def("__repr__", [](PyLocation &self) {
2401         PyPrintAccumulator printAccum;
2402         mlirLocationPrint(self, printAccum.getCallback(),
2403                           printAccum.getUserData());
2404         return printAccum.join();
2405       });
2406 
2407   //----------------------------------------------------------------------------
2408   // Mapping of Module
2409   //----------------------------------------------------------------------------
2410   py::class_<PyModule>(m, "Module", py::module_local())
2411       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2412       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2413       .def_static(
2414           "parse",
2415           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2416             MlirModule module = mlirModuleCreateParse(
2417                 context->get(), toMlirStringRef(moduleAsm));
2418             // TODO: Rework error reporting once diagnostic engine is exposed
2419             // in C API.
2420             if (mlirModuleIsNull(module)) {
2421               throw SetPyError(
2422                   PyExc_ValueError,
2423                   "Unable to parse module assembly (see diagnostics)");
2424             }
2425             return PyModule::forModule(module).releaseObject();
2426           },
2427           py::arg("asm"), py::arg("context") = py::none(),
2428           kModuleParseDocstring)
2429       .def_static(
2430           "create",
2431           [](DefaultingPyLocation loc) {
2432             MlirModule module = mlirModuleCreateEmpty(loc);
2433             return PyModule::forModule(module).releaseObject();
2434           },
2435           py::arg("loc") = py::none(), "Creates an empty module")
2436       .def_property_readonly(
2437           "context",
2438           [](PyModule &self) { return self.getContext().getObject(); },
2439           "Context that created the Module")
2440       .def_property_readonly(
2441           "operation",
2442           [](PyModule &self) {
2443             return PyOperation::forOperation(self.getContext(),
2444                                              mlirModuleGetOperation(self.get()),
2445                                              self.getRef().releaseObject())
2446                 .releaseObject();
2447           },
2448           "Accesses the module as an operation")
2449       .def_property_readonly(
2450           "body",
2451           [](PyModule &self) {
2452             PyOperationRef moduleOp = PyOperation::forOperation(
2453                 self.getContext(), mlirModuleGetOperation(self.get()),
2454                 self.getRef().releaseObject());
2455             PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2456             return returnBlock;
2457           },
2458           "Return the block for this module")
2459       .def(
2460           "dump",
2461           [](PyModule &self) {
2462             mlirOperationDump(mlirModuleGetOperation(self.get()));
2463           },
2464           kDumpDocstring)
2465       .def(
2466           "__str__",
2467           [](py::object self) {
2468             // Defer to the operation's __str__.
2469             return self.attr("operation").attr("__str__")();
2470           },
2471           kOperationStrDunderDocstring);
2472 
2473   //----------------------------------------------------------------------------
2474   // Mapping of Operation.
2475   //----------------------------------------------------------------------------
2476   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2477       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2478                              [](PyOperationBase &self) {
2479                                return self.getOperation().getCapsule();
2480                              })
2481       .def("__eq__",
2482            [](PyOperationBase &self, PyOperationBase &other) {
2483              return &self.getOperation() == &other.getOperation();
2484            })
2485       .def("__eq__",
2486            [](PyOperationBase &self, py::object other) { return false; })
2487       .def("__hash__",
2488            [](PyOperationBase &self) {
2489              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2490            })
2491       .def_property_readonly("attributes",
2492                              [](PyOperationBase &self) {
2493                                return PyOpAttributeMap(
2494                                    self.getOperation().getRef());
2495                              })
2496       .def_property_readonly("operands",
2497                              [](PyOperationBase &self) {
2498                                return PyOpOperandList(
2499                                    self.getOperation().getRef());
2500                              })
2501       .def_property_readonly("regions",
2502                              [](PyOperationBase &self) {
2503                                return PyRegionList(
2504                                    self.getOperation().getRef());
2505                              })
2506       .def_property_readonly(
2507           "results",
2508           [](PyOperationBase &self) {
2509             return PyOpResultList(self.getOperation().getRef());
2510           },
2511           "Returns the list of Operation results.")
2512       .def_property_readonly(
2513           "result",
2514           [](PyOperationBase &self) {
2515             auto &operation = self.getOperation();
2516             auto numResults = mlirOperationGetNumResults(operation);
2517             if (numResults != 1) {
2518               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2519               throw SetPyError(
2520                   PyExc_ValueError,
2521                   Twine("Cannot call .result on operation ") +
2522                       StringRef(name.data, name.length) + " which has " +
2523                       Twine(numResults) +
2524                       " results (it is only valid for operations with a "
2525                       "single result)");
2526             }
2527             return PyOpResult(operation.getRef(),
2528                               mlirOperationGetResult(operation, 0));
2529           },
2530           "Shortcut to get an op result if it has only one (throws an error "
2531           "otherwise).")
2532       .def_property_readonly(
2533           "location",
2534           [](PyOperationBase &self) {
2535             PyOperation &operation = self.getOperation();
2536             return PyLocation(operation.getContext(),
2537                               mlirOperationGetLocation(operation.get()));
2538           },
2539           "Returns the source location the operation was defined or derived "
2540           "from.")
2541       .def(
2542           "__str__",
2543           [](PyOperationBase &self) {
2544             return self.getAsm(/*binary=*/false,
2545                                /*largeElementsLimit=*/llvm::None,
2546                                /*enableDebugInfo=*/false,
2547                                /*prettyDebugInfo=*/false,
2548                                /*printGenericOpForm=*/false,
2549                                /*useLocalScope=*/false,
2550                                /*assumeVerified=*/false);
2551           },
2552           "Returns the assembly form of the operation.")
2553       .def("print", &PyOperationBase::print,
2554            // Careful: Lots of arguments must match up with print method.
2555            py::arg("file") = py::none(), py::arg("binary") = false,
2556            py::arg("large_elements_limit") = py::none(),
2557            py::arg("enable_debug_info") = false,
2558            py::arg("pretty_debug_info") = false,
2559            py::arg("print_generic_op_form") = false,
2560            py::arg("use_local_scope") = false,
2561            py::arg("assume_verified") = false, kOperationPrintDocstring)
2562       .def("get_asm", &PyOperationBase::getAsm,
2563            // Careful: Lots of arguments must match up with get_asm method.
2564            py::arg("binary") = false,
2565            py::arg("large_elements_limit") = py::none(),
2566            py::arg("enable_debug_info") = false,
2567            py::arg("pretty_debug_info") = false,
2568            py::arg("print_generic_op_form") = false,
2569            py::arg("use_local_scope") = false,
2570            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
2571       .def(
2572           "verify",
2573           [](PyOperationBase &self) {
2574             return mlirOperationVerify(self.getOperation());
2575           },
2576           "Verify the operation and return true if it passes, false if it "
2577           "fails.")
2578       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2579            "Puts self immediately after the other operation in its parent "
2580            "block.")
2581       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2582            "Puts self immediately before the other operation in its parent "
2583            "block.")
2584       .def(
2585           "detach_from_parent",
2586           [](PyOperationBase &self) {
2587             PyOperation &operation = self.getOperation();
2588             operation.checkValid();
2589             if (!operation.isAttached())
2590               throw py::value_error("Detached operation has no parent.");
2591 
2592             operation.detachFromParent();
2593             return operation.createOpView();
2594           },
2595           "Detaches the operation from its parent block.");
2596 
2597   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2598       .def_static("create", &PyOperation::create, py::arg("name"),
2599                   py::arg("results") = py::none(),
2600                   py::arg("operands") = py::none(),
2601                   py::arg("attributes") = py::none(),
2602                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2603                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2604                   kOperationCreateDocstring)
2605       .def_property_readonly("parent",
2606                              [](PyOperation &self) -> py::object {
2607                                auto parent = self.getParentOperation();
2608                                if (parent)
2609                                  return parent->getObject();
2610                                return py::none();
2611                              })
2612       .def("erase", &PyOperation::erase)
2613       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2614                              &PyOperation::getCapsule)
2615       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2616       .def_property_readonly("name",
2617                              [](PyOperation &self) {
2618                                self.checkValid();
2619                                MlirOperation operation = self.get();
2620                                MlirStringRef name = mlirIdentifierStr(
2621                                    mlirOperationGetName(operation));
2622                                return py::str(name.data, name.length);
2623                              })
2624       .def_property_readonly(
2625           "context",
2626           [](PyOperation &self) {
2627             self.checkValid();
2628             return self.getContext().getObject();
2629           },
2630           "Context that owns the Operation")
2631       .def_property_readonly("opview", &PyOperation::createOpView);
2632 
2633   auto opViewClass =
2634       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2635           .def(py::init<py::object>(), py::arg("operation"))
2636           .def_property_readonly("operation", &PyOpView::getOperationObject)
2637           .def_property_readonly(
2638               "context",
2639               [](PyOpView &self) {
2640                 return self.getOperation().getContext().getObject();
2641               },
2642               "Context that owns the Operation")
2643           .def("__str__", [](PyOpView &self) {
2644             return py::str(self.getOperationObject());
2645           });
2646   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2647   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2648   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2649   opViewClass.attr("build_generic") = classmethod(
2650       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2651       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2652       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2653       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2654       "Builds a specific, generated OpView based on class level attributes.");
2655 
2656   //----------------------------------------------------------------------------
2657   // Mapping of PyRegion.
2658   //----------------------------------------------------------------------------
2659   py::class_<PyRegion>(m, "Region", py::module_local())
2660       .def_property_readonly(
2661           "blocks",
2662           [](PyRegion &self) {
2663             return PyBlockList(self.getParentOperation(), self.get());
2664           },
2665           "Returns a forward-optimized sequence of blocks.")
2666       .def_property_readonly(
2667           "owner",
2668           [](PyRegion &self) {
2669             return self.getParentOperation()->createOpView();
2670           },
2671           "Returns the operation owning this region.")
2672       .def(
2673           "__iter__",
2674           [](PyRegion &self) {
2675             self.checkValid();
2676             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2677             return PyBlockIterator(self.getParentOperation(), firstBlock);
2678           },
2679           "Iterates over blocks in the region.")
2680       .def("__eq__",
2681            [](PyRegion &self, PyRegion &other) {
2682              return self.get().ptr == other.get().ptr;
2683            })
2684       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2685 
2686   //----------------------------------------------------------------------------
2687   // Mapping of PyBlock.
2688   //----------------------------------------------------------------------------
2689   py::class_<PyBlock>(m, "Block", py::module_local())
2690       .def_property_readonly(
2691           "owner",
2692           [](PyBlock &self) {
2693             return self.getParentOperation()->createOpView();
2694           },
2695           "Returns the owning operation of this block.")
2696       .def_property_readonly(
2697           "region",
2698           [](PyBlock &self) {
2699             MlirRegion region = mlirBlockGetParentRegion(self.get());
2700             return PyRegion(self.getParentOperation(), region);
2701           },
2702           "Returns the owning region of this block.")
2703       .def_property_readonly(
2704           "arguments",
2705           [](PyBlock &self) {
2706             return PyBlockArgumentList(self.getParentOperation(), self.get());
2707           },
2708           "Returns a list of block arguments.")
2709       .def_property_readonly(
2710           "operations",
2711           [](PyBlock &self) {
2712             return PyOperationList(self.getParentOperation(), self.get());
2713           },
2714           "Returns a forward-optimized sequence of operations.")
2715       .def_static(
2716           "create_at_start",
2717           [](PyRegion &parent, py::list pyArgTypes) {
2718             parent.checkValid();
2719             llvm::SmallVector<MlirType, 4> argTypes;
2720             argTypes.reserve(pyArgTypes.size());
2721             for (auto &pyArg : pyArgTypes) {
2722               argTypes.push_back(pyArg.cast<PyType &>());
2723             }
2724 
2725             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2726             mlirRegionInsertOwnedBlock(parent, 0, block);
2727             return PyBlock(parent.getParentOperation(), block);
2728           },
2729           py::arg("parent"), py::arg("arg_types") = py::list(),
2730           "Creates and returns a new Block at the beginning of the given "
2731           "region (with given argument types).")
2732       .def(
2733           "create_before",
2734           [](PyBlock &self, py::args pyArgTypes) {
2735             self.checkValid();
2736             llvm::SmallVector<MlirType, 4> argTypes;
2737             argTypes.reserve(pyArgTypes.size());
2738             for (auto &pyArg : pyArgTypes) {
2739               argTypes.push_back(pyArg.cast<PyType &>());
2740             }
2741 
2742             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2743             MlirRegion region = mlirBlockGetParentRegion(self.get());
2744             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2745             return PyBlock(self.getParentOperation(), block);
2746           },
2747           "Creates and returns a new Block before this block "
2748           "(with given argument types).")
2749       .def(
2750           "create_after",
2751           [](PyBlock &self, py::args pyArgTypes) {
2752             self.checkValid();
2753             llvm::SmallVector<MlirType, 4> argTypes;
2754             argTypes.reserve(pyArgTypes.size());
2755             for (auto &pyArg : pyArgTypes) {
2756               argTypes.push_back(pyArg.cast<PyType &>());
2757             }
2758 
2759             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2760             MlirRegion region = mlirBlockGetParentRegion(self.get());
2761             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2762             return PyBlock(self.getParentOperation(), block);
2763           },
2764           "Creates and returns a new Block after this block "
2765           "(with given argument types).")
2766       .def(
2767           "__iter__",
2768           [](PyBlock &self) {
2769             self.checkValid();
2770             MlirOperation firstOperation =
2771                 mlirBlockGetFirstOperation(self.get());
2772             return PyOperationIterator(self.getParentOperation(),
2773                                        firstOperation);
2774           },
2775           "Iterates over operations in the block.")
2776       .def("__eq__",
2777            [](PyBlock &self, PyBlock &other) {
2778              return self.get().ptr == other.get().ptr;
2779            })
2780       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2781       .def(
2782           "__str__",
2783           [](PyBlock &self) {
2784             self.checkValid();
2785             PyPrintAccumulator printAccum;
2786             mlirBlockPrint(self.get(), printAccum.getCallback(),
2787                            printAccum.getUserData());
2788             return printAccum.join();
2789           },
2790           "Returns the assembly form of the block.")
2791       .def(
2792           "append",
2793           [](PyBlock &self, PyOperationBase &operation) {
2794             if (operation.getOperation().isAttached())
2795               operation.getOperation().detachFromParent();
2796 
2797             MlirOperation mlirOperation = operation.getOperation().get();
2798             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2799             operation.getOperation().setAttached(
2800                 self.getParentOperation().getObject());
2801           },
2802           py::arg("operation"),
2803           "Appends an operation to this block. If the operation is currently "
2804           "in another block, it will be moved.");
2805 
2806   //----------------------------------------------------------------------------
2807   // Mapping of PyInsertionPoint.
2808   //----------------------------------------------------------------------------
2809 
2810   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2811       .def(py::init<PyBlock &>(), py::arg("block"),
2812            "Inserts after the last operation but still inside the block.")
2813       .def("__enter__", &PyInsertionPoint::contextEnter)
2814       .def("__exit__", &PyInsertionPoint::contextExit)
2815       .def_property_readonly_static(
2816           "current",
2817           [](py::object & /*class*/) {
2818             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2819             if (!ip)
2820               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2821             return ip;
2822           },
2823           "Gets the InsertionPoint bound to the current thread or raises "
2824           "ValueError if none has been set")
2825       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2826            "Inserts before a referenced operation.")
2827       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2828                   py::arg("block"), "Inserts at the beginning of the block.")
2829       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2830                   py::arg("block"), "Inserts before the block terminator.")
2831       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2832            "Inserts an operation.")
2833       .def_property_readonly(
2834           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2835           "Returns the block that this InsertionPoint points to.");
2836 
2837   //----------------------------------------------------------------------------
2838   // Mapping of PyAttribute.
2839   //----------------------------------------------------------------------------
2840   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2841       // Delegate to the PyAttribute copy constructor, which will also lifetime
2842       // extend the backing context which owns the MlirAttribute.
2843       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2844            "Casts the passed attribute to the generic Attribute")
2845       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2846                              &PyAttribute::getCapsule)
2847       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2848       .def_static(
2849           "parse",
2850           [](std::string attrSpec, DefaultingPyMlirContext context) {
2851             MlirAttribute type = mlirAttributeParseGet(
2852                 context->get(), toMlirStringRef(attrSpec));
2853             // TODO: Rework error reporting once diagnostic engine is exposed
2854             // in C API.
2855             if (mlirAttributeIsNull(type)) {
2856               throw SetPyError(PyExc_ValueError,
2857                                Twine("Unable to parse attribute: '") +
2858                                    attrSpec + "'");
2859             }
2860             return PyAttribute(context->getRef(), type);
2861           },
2862           py::arg("asm"), py::arg("context") = py::none(),
2863           "Parses an attribute from an assembly form")
2864       .def_property_readonly(
2865           "context",
2866           [](PyAttribute &self) { return self.getContext().getObject(); },
2867           "Context that owns the Attribute")
2868       .def_property_readonly("type",
2869                              [](PyAttribute &self) {
2870                                return PyType(self.getContext()->getRef(),
2871                                              mlirAttributeGetType(self));
2872                              })
2873       .def(
2874           "get_named",
2875           [](PyAttribute &self, std::string name) {
2876             return PyNamedAttribute(self, std::move(name));
2877           },
2878           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2879       .def("__eq__",
2880            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2881       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2882       .def("__hash__",
2883            [](PyAttribute &self) {
2884              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2885            })
2886       .def(
2887           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2888           kDumpDocstring)
2889       .def(
2890           "__str__",
2891           [](PyAttribute &self) {
2892             PyPrintAccumulator printAccum;
2893             mlirAttributePrint(self, printAccum.getCallback(),
2894                                printAccum.getUserData());
2895             return printAccum.join();
2896           },
2897           "Returns the assembly form of the Attribute.")
2898       .def("__repr__", [](PyAttribute &self) {
2899         // Generally, assembly formats are not printed for __repr__ because
2900         // this can cause exceptionally long debug output and exceptions.
2901         // However, attribute values are generally considered useful and are
2902         // printed. This may need to be re-evaluated if debug dumps end up
2903         // being excessive.
2904         PyPrintAccumulator printAccum;
2905         printAccum.parts.append("Attribute(");
2906         mlirAttributePrint(self, printAccum.getCallback(),
2907                            printAccum.getUserData());
2908         printAccum.parts.append(")");
2909         return printAccum.join();
2910       });
2911 
2912   //----------------------------------------------------------------------------
2913   // Mapping of PyNamedAttribute
2914   //----------------------------------------------------------------------------
2915   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2916       .def("__repr__",
2917            [](PyNamedAttribute &self) {
2918              PyPrintAccumulator printAccum;
2919              printAccum.parts.append("NamedAttribute(");
2920              printAccum.parts.append(
2921                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
2922                          mlirIdentifierStr(self.namedAttr.name).length));
2923              printAccum.parts.append("=");
2924              mlirAttributePrint(self.namedAttr.attribute,
2925                                 printAccum.getCallback(),
2926                                 printAccum.getUserData());
2927              printAccum.parts.append(")");
2928              return printAccum.join();
2929            })
2930       .def_property_readonly(
2931           "name",
2932           [](PyNamedAttribute &self) {
2933             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2934                            mlirIdentifierStr(self.namedAttr.name).length);
2935           },
2936           "The name of the NamedAttribute binding")
2937       .def_property_readonly(
2938           "attr",
2939           [](PyNamedAttribute &self) {
2940             // TODO: When named attribute is removed/refactored, also remove
2941             // this constructor (it does an inefficient table lookup).
2942             auto contextRef = PyMlirContext::forContext(
2943                 mlirAttributeGetContext(self.namedAttr.attribute));
2944             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2945           },
2946           py::keep_alive<0, 1>(),
2947           "The underlying generic attribute of the NamedAttribute binding");
2948 
2949   //----------------------------------------------------------------------------
2950   // Mapping of PyType.
2951   //----------------------------------------------------------------------------
2952   py::class_<PyType>(m, "Type", py::module_local())
2953       // Delegate to the PyType copy constructor, which will also lifetime
2954       // extend the backing context which owns the MlirType.
2955       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2956            "Casts the passed type to the generic Type")
2957       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2958       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2959       .def_static(
2960           "parse",
2961           [](std::string typeSpec, DefaultingPyMlirContext context) {
2962             MlirType type =
2963                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2964             // TODO: Rework error reporting once diagnostic engine is exposed
2965             // in C API.
2966             if (mlirTypeIsNull(type)) {
2967               throw SetPyError(PyExc_ValueError,
2968                                Twine("Unable to parse type: '") + typeSpec +
2969                                    "'");
2970             }
2971             return PyType(context->getRef(), type);
2972           },
2973           py::arg("asm"), py::arg("context") = py::none(),
2974           kContextParseTypeDocstring)
2975       .def_property_readonly(
2976           "context", [](PyType &self) { return self.getContext().getObject(); },
2977           "Context that owns the Type")
2978       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2979       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2980       .def("__hash__",
2981            [](PyType &self) {
2982              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2983            })
2984       .def(
2985           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2986       .def(
2987           "__str__",
2988           [](PyType &self) {
2989             PyPrintAccumulator printAccum;
2990             mlirTypePrint(self, printAccum.getCallback(),
2991                           printAccum.getUserData());
2992             return printAccum.join();
2993           },
2994           "Returns the assembly form of the type.")
2995       .def("__repr__", [](PyType &self) {
2996         // Generally, assembly formats are not printed for __repr__ because
2997         // this can cause exceptionally long debug output and exceptions.
2998         // However, types are an exception as they typically have compact
2999         // assembly forms and printing them is useful.
3000         PyPrintAccumulator printAccum;
3001         printAccum.parts.append("Type(");
3002         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
3003         printAccum.parts.append(")");
3004         return printAccum.join();
3005       });
3006 
3007   //----------------------------------------------------------------------------
3008   // Mapping of Value.
3009   //----------------------------------------------------------------------------
3010   py::class_<PyValue>(m, "Value", py::module_local())
3011       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3012       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3013       .def_property_readonly(
3014           "context",
3015           [](PyValue &self) { return self.getParentOperation()->getContext(); },
3016           "Context in which the value lives.")
3017       .def(
3018           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3019           kDumpDocstring)
3020       .def_property_readonly(
3021           "owner",
3022           [](PyValue &self) {
3023             assert(mlirOperationEqual(self.getParentOperation()->get(),
3024                                       mlirOpResultGetOwner(self.get())) &&
3025                    "expected the owner of the value in Python to match that in "
3026                    "the IR");
3027             return self.getParentOperation().getObject();
3028           })
3029       .def("__eq__",
3030            [](PyValue &self, PyValue &other) {
3031              return self.get().ptr == other.get().ptr;
3032            })
3033       .def("__eq__", [](PyValue &self, py::object other) { return false; })
3034       .def("__hash__",
3035            [](PyValue &self) {
3036              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3037            })
3038       .def(
3039           "__str__",
3040           [](PyValue &self) {
3041             PyPrintAccumulator printAccum;
3042             printAccum.parts.append("Value(");
3043             mlirValuePrint(self.get(), printAccum.getCallback(),
3044                            printAccum.getUserData());
3045             printAccum.parts.append(")");
3046             return printAccum.join();
3047           },
3048           kValueDunderStrDocstring)
3049       .def_property_readonly("type", [](PyValue &self) {
3050         return PyType(self.getParentOperation()->getContext(),
3051                       mlirValueGetType(self.get()));
3052       });
  // Bind the concrete Value subclasses (block arguments and op results).
  // NOTE(review): presumably these must stay after the "Value" binding above
  // so they can derive from it on the Python side; confirm before reordering.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);
3055 
3056   //----------------------------------------------------------------------------
3057   // Mapping of SymbolTable.
3058   //----------------------------------------------------------------------------
3059   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3060       .def(py::init<PyOperationBase &>())
3061       .def("__getitem__", &PySymbolTable::dunderGetItem)
3062       .def("insert", &PySymbolTable::insert, py::arg("operation"))
3063       .def("erase", &PySymbolTable::erase, py::arg("operation"))
3064       .def("__delitem__", &PySymbolTable::dunderDel)
3065       .def("__contains__",
3066            [](PySymbolTable &table, const std::string &name) {
3067              return !mlirOperationIsNull(mlirSymbolTableLookup(
3068                  table, mlirStringRefCreate(name.data(), name.length())));
3069            })
3070       // Static helpers.
3071       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3072                   py::arg("symbol"), py::arg("name"))
3073       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3074                   py::arg("symbol"))
3075       .def_static("get_visibility", &PySymbolTable::getVisibility,
3076                   py::arg("symbol"))
3077       .def_static("set_visibility", &PySymbolTable::setVisibility,
3078                   py::arg("symbol"), py::arg("visibility"))
3079       .def_static("replace_all_symbol_uses",
3080                   &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3081                   py::arg("new_symbol"), py::arg("from_op"))
3082       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3083                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
3084                   py::arg("callback"));
3085 
  // Container bindings.
  // Each call registers that helper's Python class(es) on module `m`
  // (presumably the list/iterator wrappers named below).
  // NOTE(review): the registration order is preserved as-is; confirm whether
  // any of these depend on one another before reordering.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  // Registers the PyGlobalDebugFlag helper on module `m`.
  PyGlobalDebugFlag::bind(m);
3100 }
3101