1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include <pybind11/stl.h>
23 
24 #include <utility>
25 
26 namespace py = pybind11;
27 using namespace mlir;
28 using namespace mlir::python;
29 
30 using llvm::SmallVector;
31 using llvm::StringRef;
32 using llvm::Twine;
33 
34 //------------------------------------------------------------------------------
35 // Docstrings (trivial, non-duplicated docstrings are included inline).
36 //------------------------------------------------------------------------------
37 
// Python-facing docstring constants. The text inside each raw string literal is
// emitted verbatim, so edits to it are user-visible. NOTE(review): the .def()
// call each constant is attached to lives outside this section.

// Docstring for parsing a type from its textual assembly form.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Docstring for constructing a call-site location.
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

// Docstring for constructing a file/line/column location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Docstring for constructing a fused location.
static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

// Docstring for constructing a named location. Note the "DocString" spelling,
// which differs from its siblings; the name is kept as-is for existing users.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

// Docstring for parsing a whole module from its textual assembly form.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring for generic operation creation; the Args section names the
// keyword arguments accepted by the Python-level create call.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
83 
// Docstring for the operation print API. Keyword-argument names listed under
// Args are user-visible and must match the Python keyword arguments exactly
// (all lowercase snake_case).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";
109 
// Docstring for retrieving an operation's assembly form as a bytes/str object.
// It defers to the print docstring for the shared keyword arguments.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Docstring for an operation's default string conversion.
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Generic docstring for debug dumps to stderr.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring for BlockList.append (bound in PyBlockList::bind).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Docstring for a value's string conversion.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
148 
149 //------------------------------------------------------------------------------
150 // Utilities.
151 //------------------------------------------------------------------------------
152 
153 /// Helper for creating an @classmethod.
154 template <class Func, typename... Args>
155 py::object classmethod(Func f, Args... args) {
156   py::object cf = py::cpp_function(f, args...);
157   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
158 }
159 
160 static py::object
161 createCustomDialectWrapper(const std::string &dialectNamespace,
162                            py::object dialectDescriptor) {
163   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
164   if (!dialectClass) {
165     // Use the base class.
166     return py::cast(PyDialect(std::move(dialectDescriptor)));
167   }
168 
169   // Create the custom implementation.
170   return (*dialectClass)(std::move(dialectDescriptor));
171 }
172 
173 static MlirStringRef toMlirStringRef(const std::string &s) {
174   return mlirStringRefCreate(s.data(), s.size());
175 }
176 
177 /// Wrapper for the global LLVM debugging flag.
178 struct PyGlobalDebugFlag {
179   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
180 
181   static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
182 
183   static void bind(py::module &m) {
184     // Debug flags.
185     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
186         .def_property_static("flag", &PyGlobalDebugFlag::get,
187                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
188   }
189 };
190 
191 //------------------------------------------------------------------------------
192 // Collections.
193 //------------------------------------------------------------------------------
194 
195 namespace {
196 
197 class PyRegionIterator {
198 public:
199   PyRegionIterator(PyOperationRef operation)
200       : operation(std::move(operation)) {}
201 
202   PyRegionIterator &dunderIter() { return *this; }
203 
204   PyRegion dunderNext() {
205     operation->checkValid();
206     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
207       throw py::stop_iteration();
208     }
209     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
210     return PyRegion(operation, region);
211   }
212 
213   static void bind(py::module &m) {
214     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
215         .def("__iter__", &PyRegionIterator::dunderIter)
216         .def("__next__", &PyRegionIterator::dunderNext);
217   }
218 
219 private:
220   PyOperationRef operation;
221   int nextIndex = 0;
222 };
223 
224 /// Regions of an op are fixed length and indexed numerically so are represented
225 /// with a sequence-like container.
226 class PyRegionList {
227 public:
228   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
229 
230   intptr_t dunderLen() {
231     operation->checkValid();
232     return mlirOperationGetNumRegions(operation->get());
233   }
234 
235   PyRegion dunderGetItem(intptr_t index) {
236     // dunderLen checks validity.
237     if (index < 0 || index >= dunderLen()) {
238       throw SetPyError(PyExc_IndexError,
239                        "attempt to access out of bounds region");
240     }
241     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
242     return PyRegion(operation, region);
243   }
244 
245   static void bind(py::module &m) {
246     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
247         .def("__len__", &PyRegionList::dunderLen)
248         .def("__getitem__", &PyRegionList::dunderGetItem);
249   }
250 
251 private:
252   PyOperationRef operation;
253 };
254 
255 class PyBlockIterator {
256 public:
257   PyBlockIterator(PyOperationRef operation, MlirBlock next)
258       : operation(std::move(operation)), next(next) {}
259 
260   PyBlockIterator &dunderIter() { return *this; }
261 
262   PyBlock dunderNext() {
263     operation->checkValid();
264     if (mlirBlockIsNull(next)) {
265       throw py::stop_iteration();
266     }
267 
268     PyBlock returnBlock(operation, next);
269     next = mlirBlockGetNextInRegion(next);
270     return returnBlock;
271   }
272 
273   static void bind(py::module &m) {
274     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
275         .def("__iter__", &PyBlockIterator::dunderIter)
276         .def("__next__", &PyBlockIterator::dunderNext);
277   }
278 
279 private:
280   PyOperationRef operation;
281   MlirBlock next;
282 };
283 
284 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
285 /// we present them as a more full-featured list-like container but optimize
286 /// it for forward iteration. Blocks are always owned by a region.
287 class PyBlockList {
288 public:
289   PyBlockList(PyOperationRef operation, MlirRegion region)
290       : operation(std::move(operation)), region(region) {}
291 
292   PyBlockIterator dunderIter() {
293     operation->checkValid();
294     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
295   }
296 
297   intptr_t dunderLen() {
298     operation->checkValid();
299     intptr_t count = 0;
300     MlirBlock block = mlirRegionGetFirstBlock(region);
301     while (!mlirBlockIsNull(block)) {
302       count += 1;
303       block = mlirBlockGetNextInRegion(block);
304     }
305     return count;
306   }
307 
308   PyBlock dunderGetItem(intptr_t index) {
309     operation->checkValid();
310     if (index < 0) {
311       throw SetPyError(PyExc_IndexError,
312                        "attempt to access out of bounds block");
313     }
314     MlirBlock block = mlirRegionGetFirstBlock(region);
315     while (!mlirBlockIsNull(block)) {
316       if (index == 0) {
317         return PyBlock(operation, block);
318       }
319       block = mlirBlockGetNextInRegion(block);
320       index -= 1;
321     }
322     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
323   }
324 
325   PyBlock appendBlock(const py::args &pyArgTypes) {
326     operation->checkValid();
327     llvm::SmallVector<MlirType, 4> argTypes;
328     llvm::SmallVector<MlirLocation, 4> argLocs;
329     argTypes.reserve(pyArgTypes.size());
330     argLocs.reserve(pyArgTypes.size());
331     for (auto &pyArg : pyArgTypes) {
332       argTypes.push_back(pyArg.cast<PyType &>());
333       // TODO: Pass in a proper location here.
334       argLocs.push_back(
335           mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
336     }
337 
338     MlirBlock block =
339         mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
340     mlirRegionAppendOwnedBlock(region, block);
341     return PyBlock(operation, block);
342   }
343 
344   static void bind(py::module &m) {
345     py::class_<PyBlockList>(m, "BlockList", py::module_local())
346         .def("__getitem__", &PyBlockList::dunderGetItem)
347         .def("__iter__", &PyBlockList::dunderIter)
348         .def("__len__", &PyBlockList::dunderLen)
349         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
350   }
351 
352 private:
353   PyOperationRef operation;
354   MlirRegion region;
355 };
356 
357 class PyOperationIterator {
358 public:
359   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
360       : parentOperation(std::move(parentOperation)), next(next) {}
361 
362   PyOperationIterator &dunderIter() { return *this; }
363 
364   py::object dunderNext() {
365     parentOperation->checkValid();
366     if (mlirOperationIsNull(next)) {
367       throw py::stop_iteration();
368     }
369 
370     PyOperationRef returnOperation =
371         PyOperation::forOperation(parentOperation->getContext(), next);
372     next = mlirOperationGetNextInBlock(next);
373     return returnOperation->createOpView();
374   }
375 
376   static void bind(py::module &m) {
377     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
378         .def("__iter__", &PyOperationIterator::dunderIter)
379         .def("__next__", &PyOperationIterator::dunderNext);
380   }
381 
382 private:
383   PyOperationRef parentOperation;
384   MlirOperation next;
385 };
386 
387 /// Operations are exposed by the C-API as a forward-only linked list. In
388 /// Python, we present them as a more full-featured list-like container but
389 /// optimize it for forward iteration. Iterable operations are always owned
390 /// by a block.
391 class PyOperationList {
392 public:
393   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
394       : parentOperation(std::move(parentOperation)), block(block) {}
395 
396   PyOperationIterator dunderIter() {
397     parentOperation->checkValid();
398     return PyOperationIterator(parentOperation,
399                                mlirBlockGetFirstOperation(block));
400   }
401 
402   intptr_t dunderLen() {
403     parentOperation->checkValid();
404     intptr_t count = 0;
405     MlirOperation childOp = mlirBlockGetFirstOperation(block);
406     while (!mlirOperationIsNull(childOp)) {
407       count += 1;
408       childOp = mlirOperationGetNextInBlock(childOp);
409     }
410     return count;
411   }
412 
413   py::object dunderGetItem(intptr_t index) {
414     parentOperation->checkValid();
415     if (index < 0) {
416       throw SetPyError(PyExc_IndexError,
417                        "attempt to access out of bounds operation");
418     }
419     MlirOperation childOp = mlirBlockGetFirstOperation(block);
420     while (!mlirOperationIsNull(childOp)) {
421       if (index == 0) {
422         return PyOperation::forOperation(parentOperation->getContext(), childOp)
423             ->createOpView();
424       }
425       childOp = mlirOperationGetNextInBlock(childOp);
426       index -= 1;
427     }
428     throw SetPyError(PyExc_IndexError,
429                      "attempt to access out of bounds operation");
430   }
431 
432   static void bind(py::module &m) {
433     py::class_<PyOperationList>(m, "OperationList", py::module_local())
434         .def("__getitem__", &PyOperationList::dunderGetItem)
435         .def("__iter__", &PyOperationList::dunderIter)
436         .def("__len__", &PyOperationList::dunderLen);
437   }
438 
439 private:
440   PyOperationRef parentOperation;
441   MlirBlock block;
442 };
443 
444 } // namespace
445 
446 //------------------------------------------------------------------------------
447 // PyMlirContext
448 //------------------------------------------------------------------------------
449 
450 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
451   py::gil_scoped_acquire acquire;
452   auto &liveContexts = getLiveContexts();
453   liveContexts[context.ptr] = this;
454 }
455 
456 PyMlirContext::~PyMlirContext() {
457   // Note that the only public way to construct an instance is via the
458   // forContext method, which always puts the associated handle into
459   // liveContexts.
460   py::gil_scoped_acquire acquire;
461   getLiveContexts().erase(context.ptr);
462   mlirContextDestroy(context);
463 }
464 
465 py::object PyMlirContext::getCapsule() {
466   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
467 }
468 
469 py::object PyMlirContext::createFromCapsule(py::object capsule) {
470   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
471   if (mlirContextIsNull(rawContext))
472     throw py::error_already_set();
473   return forContext(rawContext).releaseObject();
474 }
475 
476 PyMlirContext *PyMlirContext::createNewContextForInit() {
477   MlirContext context = mlirContextCreate();
478   mlirRegisterAllDialects(context);
479   return new PyMlirContext(context);
480 }
481 
482 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
483   py::gil_scoped_acquire acquire;
484   auto &liveContexts = getLiveContexts();
485   auto it = liveContexts.find(context.ptr);
486   if (it == liveContexts.end()) {
487     // Create.
488     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
489     py::object pyRef = py::cast(unownedContextWrapper);
490     assert(pyRef && "cast to py::object failed");
491     liveContexts[context.ptr] = unownedContextWrapper;
492     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
493   }
494   // Use existing.
495   py::object pyRef = py::cast(it->second);
496   return PyMlirContextRef(it->second, std::move(pyRef));
497 }
498 
499 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
500   static LiveContextMap liveContexts;
501   return liveContexts;
502 }
503 
504 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
505 
506 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
507 
508 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
509 
510 pybind11::object PyMlirContext::contextEnter() {
511   return PyThreadContextEntry::pushContext(*this);
512 }
513 
514 void PyMlirContext::contextExit(const pybind11::object &excType,
515                                 const pybind11::object &excVal,
516                                 const pybind11::object &excTb) {
517   PyThreadContextEntry::popContext(*this);
518 }
519 
520 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
521   // Note that ownership is transferred to the delete callback below by way of
522   // an explicit inc_ref (borrow).
523   PyDiagnosticHandler *pyHandler =
524       new PyDiagnosticHandler(get(), std::move(callback));
525   py::object pyHandlerObject =
526       py::cast(pyHandler, py::return_value_policy::take_ownership);
527   pyHandlerObject.inc_ref();
528 
529   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
530   // guaranteed to be known to pybind.
531   auto handlerCallback =
532       +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
533     PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
534     py::object pyDiagnosticObject =
535         py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
536 
537     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
538     bool result = false;
539     {
540       // Since this can be called from arbitrary C++ contexts, always get the
541       // gil.
542       py::gil_scoped_acquire gil;
543       try {
544         result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
545       } catch (std::exception &e) {
546         fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
547                 e.what());
548         pyHandler->hadError = true;
549       }
550     }
551 
552     pyDiagnostic->invalidate();
553     return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
554   };
555   auto deleteCallback = +[](void *userData) {
556     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
557     assert(pyHandler->registeredID && "handler is not registered");
558     pyHandler->registeredID.reset();
559 
560     // Decrement reference, balancing the inc_ref() above.
561     py::object pyHandlerObject =
562         py::cast(pyHandler, py::return_value_policy::reference);
563     pyHandlerObject.dec_ref();
564   };
565 
566   pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
567       get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
568   return pyHandlerObject;
569 }
570 
571 PyMlirContext &DefaultingPyMlirContext::resolve() {
572   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
573   if (!context) {
574     throw SetPyError(
575         PyExc_RuntimeError,
576         "An MLIR function requires a Context but none was provided in the call "
577         "or from the surrounding environment. Either pass to the function with "
578         "a 'context=' argument or establish a default using 'with Context():'");
579   }
580   return *context;
581 }
582 
583 //------------------------------------------------------------------------------
584 // PyThreadContextEntry management
585 //------------------------------------------------------------------------------
586 
587 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
588   static thread_local std::vector<PyThreadContextEntry> stack;
589   return stack;
590 }
591 
592 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
593   auto &stack = getStack();
594   if (stack.empty())
595     return nullptr;
596   return &stack.back();
597 }
598 
599 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
600                                 py::object insertionPoint,
601                                 py::object location) {
602   auto &stack = getStack();
603   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
604                      std::move(location));
605   // If the new stack has more than one entry and the context of the new top
606   // entry matches the previous, copy the insertionPoint and location from the
607   // previous entry if missing from the new top entry.
608   if (stack.size() > 1) {
609     auto &prev = *(stack.rbegin() + 1);
610     auto &current = stack.back();
611     if (current.context.is(prev.context)) {
612       // Default non-context objects from the previous entry.
613       if (!current.insertionPoint)
614         current.insertionPoint = prev.insertionPoint;
615       if (!current.location)
616         current.location = prev.location;
617     }
618   }
619 }
620 
621 PyMlirContext *PyThreadContextEntry::getContext() {
622   if (!context)
623     return nullptr;
624   return py::cast<PyMlirContext *>(context);
625 }
626 
627 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
628   if (!insertionPoint)
629     return nullptr;
630   return py::cast<PyInsertionPoint *>(insertionPoint);
631 }
632 
633 PyLocation *PyThreadContextEntry::getLocation() {
634   if (!location)
635     return nullptr;
636   return py::cast<PyLocation *>(location);
637 }
638 
639 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
640   auto *tos = getTopOfStack();
641   return tos ? tos->getContext() : nullptr;
642 }
643 
644 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
645   auto *tos = getTopOfStack();
646   return tos ? tos->getInsertionPoint() : nullptr;
647 }
648 
649 PyLocation *PyThreadContextEntry::getDefaultLocation() {
650   auto *tos = getTopOfStack();
651   return tos ? tos->getLocation() : nullptr;
652 }
653 
654 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
655   py::object contextObj = py::cast(context);
656   push(FrameKind::Context, /*context=*/contextObj,
657        /*insertionPoint=*/py::object(),
658        /*location=*/py::object());
659   return contextObj;
660 }
661 
662 void PyThreadContextEntry::popContext(PyMlirContext &context) {
663   auto &stack = getStack();
664   if (stack.empty())
665     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
666   auto &tos = stack.back();
667   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
668     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
669   stack.pop_back();
670 }
671 
672 py::object
673 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
674   py::object contextObj =
675       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
676   py::object insertionPointObj = py::cast(insertionPoint);
677   push(FrameKind::InsertionPoint,
678        /*context=*/contextObj,
679        /*insertionPoint=*/insertionPointObj,
680        /*location=*/py::object());
681   return insertionPointObj;
682 }
683 
684 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
685   auto &stack = getStack();
686   if (stack.empty())
687     throw SetPyError(PyExc_RuntimeError,
688                      "Unbalanced InsertionPoint enter/exit");
689   auto &tos = stack.back();
690   if (tos.frameKind != FrameKind::InsertionPoint &&
691       tos.getInsertionPoint() != &insertionPoint)
692     throw SetPyError(PyExc_RuntimeError,
693                      "Unbalanced InsertionPoint enter/exit");
694   stack.pop_back();
695 }
696 
697 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
698   py::object contextObj = location.getContext().getObject();
699   py::object locationObj = py::cast(location);
700   push(FrameKind::Location, /*context=*/contextObj,
701        /*insertionPoint=*/py::object(),
702        /*location=*/locationObj);
703   return locationObj;
704 }
705 
706 void PyThreadContextEntry::popLocation(PyLocation &location) {
707   auto &stack = getStack();
708   if (stack.empty())
709     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
710   auto &tos = stack.back();
711   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
712     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
713   stack.pop_back();
714 }
715 
716 //------------------------------------------------------------------------------
717 // PyDiagnostic*
718 //------------------------------------------------------------------------------
719 
720 void PyDiagnostic::invalidate() {
721   valid = false;
722   if (materializedNotes) {
723     for (auto &noteObject : *materializedNotes) {
724       PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
725       note->invalidate();
726     }
727   }
728 }
729 
730 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
731                                          py::object callback)
732     : context(context), callback(std::move(callback)) {}
733 
734 PyDiagnosticHandler::~PyDiagnosticHandler() = default;
735 
736 void PyDiagnosticHandler::detach() {
737   if (!registeredID)
738     return;
739   MlirDiagnosticHandlerID localID = *registeredID;
740   mlirContextDetachDiagnosticHandler(context, localID);
741   assert(!registeredID && "should have unregistered");
742   // Not strictly necessary but keeps stale pointers from being around to cause
743   // issues.
744   context = {nullptr};
745 }
746 
747 void PyDiagnostic::checkValid() {
748   if (!valid) {
749     throw std::invalid_argument(
750         "Diagnostic is invalid (used outside of callback)");
751   }
752 }
753 
754 MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
755   checkValid();
756   return mlirDiagnosticGetSeverity(diagnostic);
757 }
758 
759 PyLocation PyDiagnostic::getLocation() {
760   checkValid();
761   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
762   MlirContext context = mlirLocationGetContext(loc);
763   return PyLocation(PyMlirContext::forContext(context), loc);
764 }
765 
766 py::str PyDiagnostic::getMessage() {
767   checkValid();
768   py::object fileObject = py::module::import("io").attr("StringIO")();
769   PyFileAccumulator accum(fileObject, /*binary=*/false);
770   mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
771   return fileObject.attr("getvalue")();
772 }
773 
774 py::tuple PyDiagnostic::getNotes() {
775   checkValid();
776   if (materializedNotes)
777     return *materializedNotes;
778   intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
779   materializedNotes = py::tuple(numNotes);
780   for (intptr_t i = 0; i < numNotes; ++i) {
781     MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
782     py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag));
783     PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr());
784   }
785   return *materializedNotes;
786 }
787 
788 //------------------------------------------------------------------------------
789 // PyDialect, PyDialectDescriptor, PyDialects
790 //------------------------------------------------------------------------------
791 
792 MlirDialect PyDialects::getDialectForKey(const std::string &key,
793                                          bool attrError) {
794   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
795                                                     {key.data(), key.size()});
796   if (mlirDialectIsNull(dialect)) {
797     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
798                      Twine("Dialect '") + key + "' not found");
799   }
800   return dialect;
801 }
802 
803 //------------------------------------------------------------------------------
804 // PyLocation
805 //------------------------------------------------------------------------------
806 
807 py::object PyLocation::getCapsule() {
808   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
809 }
810 
811 PyLocation PyLocation::createFromCapsule(py::object capsule) {
812   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
813   if (mlirLocationIsNull(rawLoc))
814     throw py::error_already_set();
815   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
816                     rawLoc);
817 }
818 
819 py::object PyLocation::contextEnter() {
820   return PyThreadContextEntry::pushLocation(*this);
821 }
822 
823 void PyLocation::contextExit(const pybind11::object &excType,
824                              const pybind11::object &excVal,
825                              const pybind11::object &excTb) {
826   PyThreadContextEntry::popLocation(*this);
827 }
828 
829 PyLocation &DefaultingPyLocation::resolve() {
830   auto *location = PyThreadContextEntry::getDefaultLocation();
831   if (!location) {
832     throw SetPyError(
833         PyExc_RuntimeError,
834         "An MLIR function requires a Location but none was provided in the "
835         "call or from the surrounding environment. Either pass to the function "
836         "with a 'loc=' argument or establish a default using 'with loc:'");
837   }
838   return *location;
839 }
840 
841 //------------------------------------------------------------------------------
842 // PyModule
843 //------------------------------------------------------------------------------
844 
845 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
846     : BaseContextObject(std::move(contextRef)), module(module) {}
847 
848 PyModule::~PyModule() {
849   py::gil_scoped_acquire acquire;
850   auto &liveModules = getContext()->liveModules;
851   assert(liveModules.count(module.ptr) == 1 &&
852          "destroying module not in live map");
853   liveModules.erase(module.ptr);
854   mlirModuleDestroy(module);
855 }
856 
/// Returns the PyModule wrapper for `module`, creating it on first request.
///
/// Wrappers are tracked in the owning context's `liveModules` map, keyed by
/// the raw module pointer. Repeated calls for the same MlirModule therefore
/// return the same Python object. A newly created wrapper is handed to Python
/// with take_ownership, so the Python object's lifetime governs deletion of
/// the PyModule.
PyModuleRef PyModule::forModule(MlirModule module) {
  MlirContext context = mlirModuleGetContext(module);
  PyMlirContextRef contextRef = PyMlirContext::forContext(context);

  // The live-module map is read and updated with the GIL held.
  py::gil_scoped_acquire acquire;
  auto &liveModules = contextRef->liveModules;
  auto it = liveModules.find(module.ptr);
  if (it == liveModules.end()) {
    // Create.
    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
    // Note that the default return value policy on cast is automatic_reference,
    // which does not take ownership (delete will not be called).
    // Just be explicit.
    py::object pyRef =
        py::cast(unownedModule, py::return_value_policy::take_ownership);
    unownedModule->handle = pyRef;
    // Register the (handle, wrapper) pair so later lookups reuse it.
    liveModules[module.ptr] =
        std::make_pair(unownedModule->handle, unownedModule);
    return PyModuleRef(unownedModule, std::move(pyRef));
  }
  // Use existing. The map holds the handle without a reference, so borrow
  // (incref) it for the returned ref.
  PyModule *existing = it->second.second;
  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
  return PyModuleRef(existing, std::move(pyRef));
}
882 
883 py::object PyModule::createFromCapsule(py::object capsule) {
884   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
885   if (mlirModuleIsNull(rawModule))
886     throw py::error_already_set();
887   return forModule(rawModule).releaseObject();
888 }
889 
890 py::object PyModule::getCapsule() {
891   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
892 }
893 
894 //------------------------------------------------------------------------------
895 // PyOperation
896 //------------------------------------------------------------------------------
897 
898 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
899     : BaseContextObject(std::move(contextRef)), operation(operation) {}
900 
901 PyOperation::~PyOperation() {
902   // If the operation has already been invalidated there is nothing to do.
903   if (!valid)
904     return;
905   auto &liveOperations = getContext()->liveOperations;
906   assert(liveOperations.count(operation.ptr) == 1 &&
907          "destroying operation not in live map");
908   liveOperations.erase(operation.ptr);
909   if (!isAttached()) {
910     mlirOperationDestroy(operation);
911   }
912 }
913 
/// Creates a new PyOperation wrapper for `operation` and registers it in the
/// context's live-operation map. The caller is responsible for ensuring that
/// no wrapper is already registered for this operation.
///
/// The wrapper is handed to Python with take_ownership, so the Python object's
/// lifetime governs deletion. If `parentKeepAlive` is set, it is stored on the
/// wrapper.
PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
                                           MlirOperation operation,
                                           py::object parentKeepAlive) {
  // Take the map reference before `contextRef` is moved into the wrapper; the
  // map itself lives in the context object, which outlives the move.
  auto &liveOperations = contextRef->liveOperations;
  // Create.
  PyOperation *unownedOperation =
      new PyOperation(std::move(contextRef), operation);
  // Note that the default return value policy on cast is automatic_reference,
  // which does not take ownership (delete will not be called).
  // Just be explicit.
  py::object pyRef =
      py::cast(unownedOperation, py::return_value_policy::take_ownership);
  unownedOperation->handle = pyRef;
  if (parentKeepAlive) {
    unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
  }
  liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
  return PyOperationRef(unownedOperation, std::move(pyRef));
}
933 
934 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
935                                          MlirOperation operation,
936                                          py::object parentKeepAlive) {
937   auto &liveOperations = contextRef->liveOperations;
938   auto it = liveOperations.find(operation.ptr);
939   if (it == liveOperations.end()) {
940     // Create.
941     return createInstance(std::move(contextRef), operation,
942                           std::move(parentKeepAlive));
943   }
944   // Use existing.
945   PyOperation *existing = it->second.second;
946   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
947   return PyOperationRef(existing, std::move(pyRef));
948 }
949 
950 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
951                                            MlirOperation operation,
952                                            py::object parentKeepAlive) {
953   auto &liveOperations = contextRef->liveOperations;
954   assert(liveOperations.count(operation.ptr) == 0 &&
955          "cannot create detached operation that already exists");
956   (void)liveOperations;
957 
958   PyOperationRef created = createInstance(std::move(contextRef), operation,
959                                           std::move(parentKeepAlive));
960   created->attached = false;
961   return created;
962 }
963 
964 void PyOperation::checkValid() const {
965   if (!valid) {
966     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
967   }
968 }
969 
970 void PyOperationBase::print(py::object fileObject, bool binary,
971                             llvm::Optional<int64_t> largeElementsLimit,
972                             bool enableDebugInfo, bool prettyDebugInfo,
973                             bool printGenericOpForm, bool useLocalScope,
974                             bool assumeVerified) {
975   PyOperation &operation = getOperation();
976   operation.checkValid();
977   if (fileObject.is_none())
978     fileObject = py::module::import("sys").attr("stdout");
979 
980   if (!assumeVerified && !printGenericOpForm &&
981       !mlirOperationVerify(operation)) {
982     std::string message("// Verification failed, printing generic form\n");
983     if (binary) {
984       fileObject.attr("write")(py::bytes(message));
985     } else {
986       fileObject.attr("write")(py::str(message));
987     }
988     printGenericOpForm = true;
989   }
990 
991   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
992   if (largeElementsLimit)
993     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
994   if (enableDebugInfo)
995     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
996   if (printGenericOpForm)
997     mlirOpPrintingFlagsPrintGenericOpForm(flags);
998 
999   PyFileAccumulator accum(fileObject, binary);
1000   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1001                               accum.getUserData());
1002   mlirOpPrintingFlagsDestroy(flags);
1003 }
1004 
1005 py::object PyOperationBase::getAsm(bool binary,
1006                                    llvm::Optional<int64_t> largeElementsLimit,
1007                                    bool enableDebugInfo, bool prettyDebugInfo,
1008                                    bool printGenericOpForm, bool useLocalScope,
1009                                    bool assumeVerified) {
1010   py::object fileObject;
1011   if (binary) {
1012     fileObject = py::module::import("io").attr("BytesIO")();
1013   } else {
1014     fileObject = py::module::import("io").attr("StringIO")();
1015   }
1016   print(fileObject, /*binary=*/binary,
1017         /*largeElementsLimit=*/largeElementsLimit,
1018         /*enableDebugInfo=*/enableDebugInfo,
1019         /*prettyDebugInfo=*/prettyDebugInfo,
1020         /*printGenericOpForm=*/printGenericOpForm,
1021         /*useLocalScope=*/useLocalScope,
1022         /*assumeVerified=*/assumeVerified);
1023 
1024   return fileObject.attr("getvalue")();
1025 }
1026 
1027 void PyOperationBase::moveAfter(PyOperationBase &other) {
1028   PyOperation &operation = getOperation();
1029   PyOperation &otherOp = other.getOperation();
1030   operation.checkValid();
1031   otherOp.checkValid();
1032   mlirOperationMoveAfter(operation, otherOp);
1033   operation.parentKeepAlive = otherOp.parentKeepAlive;
1034 }
1035 
1036 void PyOperationBase::moveBefore(PyOperationBase &other) {
1037   PyOperation &operation = getOperation();
1038   PyOperation &otherOp = other.getOperation();
1039   operation.checkValid();
1040   otherOp.checkValid();
1041   mlirOperationMoveBefore(operation, otherOp);
1042   operation.parentKeepAlive = otherOp.parentKeepAlive;
1043 }
1044 
1045 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
1046   checkValid();
1047   if (!isAttached())
1048     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
1049   MlirOperation operation = mlirOperationGetParentOperation(get());
1050   if (mlirOperationIsNull(operation))
1051     return {};
1052   return PyOperation::forOperation(getContext(), operation);
1053 }
1054 
1055 PyBlock PyOperation::getBlock() {
1056   checkValid();
1057   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
1058   MlirBlock block = mlirOperationGetBlock(get());
1059   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1060   assert(parentOperation && "Operation has no parent");
1061   return PyBlock{std::move(*parentOperation), block};
1062 }
1063 
1064 py::object PyOperation::getCapsule() {
1065   checkValid();
1066   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1067 }
1068 
1069 py::object PyOperation::createFromCapsule(py::object capsule) {
1070   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1071   if (mlirOperationIsNull(rawOperation))
1072     throw py::error_already_set();
1073   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1074   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1075       .releaseObject();
1076 }
1077 
/// Creates a new operation named `name` and returns its OpView (see
/// createOpView()).
///
/// Validation happens first:
///   - `regions` < 0 raises ValueError.
///   - A None operand, result type or successor raises ValueError.
///   - A non-string attribute key or non-Attribute value raises
///     py::cast_error.
/// This ordering ensures nothing throws once the MlirOperationState has been
/// populated.
///
/// The operation is created detached. What happens next depends on `maybeIp`:
///   - exactly `False`: the operation is left detached.
///   - None: the operation is inserted at
///     PyThreadContextEntry::getDefaultInsertionPoint(), if that is non-null.
///   - otherwise: `maybeIp` is cast to an InsertionPoint and used.
py::object PyOperation::create(
    const std::string &name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, const py::object &maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Fresh, empty regions; "Owned" indicates they are handed to the state.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);

  // InsertPoint active?
  // `False` (by identity) opts out of insertion entirely; None means "use the
  // default insertion point, if any".
  if (!maybeIp.is(py::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
      ip = PyThreadContextEntry::getDefaultInsertionPoint();
    } else {
      ip = py::cast<PyInsertionPoint *>(maybeIp);
    }
    if (ip)
      ip->insert(*created.get());
  }

  return created->createOpView();
}
1210 
1211 py::object PyOperation::createOpView() {
1212   checkValid();
1213   MlirIdentifier ident = mlirOperationGetName(get());
1214   MlirStringRef identStr = mlirIdentifierStr(ident);
1215   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1216       StringRef(identStr.data, identStr.length));
1217   if (opViewClass)
1218     return (*opViewClass)(getRef().getObject());
1219   return py::cast(PyOpView(getRef().getObject()));
1220 }
1221 
1222 void PyOperation::erase() {
1223   checkValid();
1224   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1225   // Python reference to a child operation is live. All children should also
1226   // have their `valid` bit set to false.
1227   auto &liveOperations = getContext()->liveOperations;
1228   if (liveOperations.count(operation.ptr))
1229     liveOperations.erase(operation.ptr);
1230   mlirOperationDestroy(operation);
1231   valid = false;
1232 }
1233 
1234 //------------------------------------------------------------------------------
1235 // PyOpView
1236 //------------------------------------------------------------------------------
1237 
1238 py::object PyOpView::buildGeneric(
1239     const py::object &cls, py::list resultTypeList, py::list operandList,
1240     llvm::Optional<py::dict> attributes,
1241     llvm::Optional<std::vector<PyBlock *>> successors,
1242     llvm::Optional<int> regions, DefaultingPyLocation location,
1243     const py::object &maybeIp) {
1244   PyMlirContextRef context = location->getContext();
1245   // Class level operation construction metadata.
1246   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1247   // Operand and result segment specs are either none, which does no
1248   // variadic unpacking, or a list of ints with segment sizes, where each
1249   // element is either a positive number (typically 1 for a scalar) or -1 to
1250   // indicate that it is derived from the length of the same-indexed operand
1251   // or result (implying that it is a list at that position).
1252   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1253   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1254 
1255   std::vector<uint32_t> operandSegmentLengths;
1256   std::vector<uint32_t> resultSegmentLengths;
1257 
1258   // Validate/determine region count.
1259   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1260   int opMinRegionCount = std::get<0>(opRegionSpec);
1261   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1262   if (!regions) {
1263     regions = opMinRegionCount;
1264   }
1265   if (*regions < opMinRegionCount) {
1266     throw py::value_error(
1267         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1268          llvm::Twine(opMinRegionCount) +
1269          " regions but was built with regions=" + llvm::Twine(*regions))
1270             .str());
1271   }
1272   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1273     throw py::value_error(
1274         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1275          llvm::Twine(opMinRegionCount) +
1276          " regions but was built with regions=" + llvm::Twine(*regions))
1277             .str());
1278   }
1279 
1280   // Unpack results.
1281   std::vector<PyType *> resultTypes;
1282   resultTypes.reserve(resultTypeList.size());
1283   if (resultSegmentSpecObj.is_none()) {
1284     // Non-variadic result unpacking.
1285     for (const auto &it : llvm::enumerate(resultTypeList)) {
1286       try {
1287         resultTypes.push_back(py::cast<PyType *>(it.value()));
1288         if (!resultTypes.back())
1289           throw py::cast_error();
1290       } catch (py::cast_error &err) {
1291         throw py::value_error((llvm::Twine("Result ") +
1292                                llvm::Twine(it.index()) + " of operation \"" +
1293                                name + "\" must be a Type (" + err.what() + ")")
1294                                   .str());
1295       }
1296     }
1297   } else {
1298     // Sized result unpacking.
1299     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1300     if (resultSegmentSpec.size() != resultTypeList.size()) {
1301       throw py::value_error((llvm::Twine("Operation \"") + name +
1302                              "\" requires " +
1303                              llvm::Twine(resultSegmentSpec.size()) +
1304                              " result segments but was provided " +
1305                              llvm::Twine(resultTypeList.size()))
1306                                 .str());
1307     }
1308     resultSegmentLengths.reserve(resultTypeList.size());
1309     for (const auto &it :
1310          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1311       int segmentSpec = std::get<1>(it.value());
1312       if (segmentSpec == 1 || segmentSpec == 0) {
1313         // Unpack unary element.
1314         try {
1315           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1316           if (resultType) {
1317             resultTypes.push_back(resultType);
1318             resultSegmentLengths.push_back(1);
1319           } else if (segmentSpec == 0) {
1320             // Allowed to be optional.
1321             resultSegmentLengths.push_back(0);
1322           } else {
1323             throw py::cast_error("was None and result is not optional");
1324           }
1325         } catch (py::cast_error &err) {
1326           throw py::value_error((llvm::Twine("Result ") +
1327                                  llvm::Twine(it.index()) + " of operation \"" +
1328                                  name + "\" must be a Type (" + err.what() +
1329                                  ")")
1330                                     .str());
1331         }
1332       } else if (segmentSpec == -1) {
1333         // Unpack sequence by appending.
1334         try {
1335           if (std::get<0>(it.value()).is_none()) {
1336             // Treat it as an empty list.
1337             resultSegmentLengths.push_back(0);
1338           } else {
1339             // Unpack the list.
1340             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1341             for (py::object segmentItem : segment) {
1342               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1343               if (!resultTypes.back()) {
1344                 throw py::cast_error("contained a None item");
1345               }
1346             }
1347             resultSegmentLengths.push_back(segment.size());
1348           }
1349         } catch (std::exception &err) {
1350           // NOTE: Sloppy to be using a catch-all here, but there are at least
1351           // three different unrelated exceptions that can be thrown in the
1352           // above "casts". Just keep the scope above small and catch them all.
1353           throw py::value_error((llvm::Twine("Result ") +
1354                                  llvm::Twine(it.index()) + " of operation \"" +
1355                                  name + "\" must be a Sequence of Types (" +
1356                                  err.what() + ")")
1357                                     .str());
1358         }
1359       } else {
1360         throw py::value_error("Unexpected segment spec");
1361       }
1362     }
1363   }
1364 
1365   // Unpack operands.
1366   std::vector<PyValue *> operands;
1367   operands.reserve(operands.size());
1368   if (operandSegmentSpecObj.is_none()) {
1369     // Non-sized operand unpacking.
1370     for (const auto &it : llvm::enumerate(operandList)) {
1371       try {
1372         operands.push_back(py::cast<PyValue *>(it.value()));
1373         if (!operands.back())
1374           throw py::cast_error();
1375       } catch (py::cast_error &err) {
1376         throw py::value_error((llvm::Twine("Operand ") +
1377                                llvm::Twine(it.index()) + " of operation \"" +
1378                                name + "\" must be a Value (" + err.what() + ")")
1379                                   .str());
1380       }
1381     }
1382   } else {
1383     // Sized operand unpacking.
1384     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1385     if (operandSegmentSpec.size() != operandList.size()) {
1386       throw py::value_error((llvm::Twine("Operation \"") + name +
1387                              "\" requires " +
1388                              llvm::Twine(operandSegmentSpec.size()) +
1389                              "operand segments but was provided " +
1390                              llvm::Twine(operandList.size()))
1391                                 .str());
1392     }
1393     operandSegmentLengths.reserve(operandList.size());
1394     for (const auto &it :
1395          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1396       int segmentSpec = std::get<1>(it.value());
1397       if (segmentSpec == 1 || segmentSpec == 0) {
1398         // Unpack unary element.
1399         try {
1400           auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1401           if (operandValue) {
1402             operands.push_back(operandValue);
1403             operandSegmentLengths.push_back(1);
1404           } else if (segmentSpec == 0) {
1405             // Allowed to be optional.
1406             operandSegmentLengths.push_back(0);
1407           } else {
1408             throw py::cast_error("was None and operand is not optional");
1409           }
1410         } catch (py::cast_error &err) {
1411           throw py::value_error((llvm::Twine("Operand ") +
1412                                  llvm::Twine(it.index()) + " of operation \"" +
1413                                  name + "\" must be a Value (" + err.what() +
1414                                  ")")
1415                                     .str());
1416         }
1417       } else if (segmentSpec == -1) {
1418         // Unpack sequence by appending.
1419         try {
1420           if (std::get<0>(it.value()).is_none()) {
1421             // Treat it as an empty list.
1422             operandSegmentLengths.push_back(0);
1423           } else {
1424             // Unpack the list.
1425             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1426             for (py::object segmentItem : segment) {
1427               operands.push_back(py::cast<PyValue *>(segmentItem));
1428               if (!operands.back()) {
1429                 throw py::cast_error("contained a None item");
1430               }
1431             }
1432             operandSegmentLengths.push_back(segment.size());
1433           }
1434         } catch (std::exception &err) {
1435           // NOTE: Sloppy to be using a catch-all here, but there are at least
1436           // three different unrelated exceptions that can be thrown in the
1437           // above "casts". Just keep the scope above small and catch them all.
1438           throw py::value_error((llvm::Twine("Operand ") +
1439                                  llvm::Twine(it.index()) + " of operation \"" +
1440                                  name + "\" must be a Sequence of Values (" +
1441                                  err.what() + ")")
1442                                     .str());
1443         }
1444       } else {
1445         throw py::value_error("Unexpected segment spec");
1446       }
1447     }
1448   }
1449 
1450   // Merge operand/result segment lengths into attributes if needed.
1451   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1452     // Dup.
1453     if (attributes) {
1454       attributes = py::dict(*attributes);
1455     } else {
1456       attributes = py::dict();
1457     }
1458     if (attributes->contains("result_segment_sizes") ||
1459         attributes->contains("operand_segment_sizes")) {
1460       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1461                             "'operand_segment_sizes' attribute is unsupported. "
1462                             "Use Operation.create for such low-level access.");
1463     }
1464 
1465     // Add result_segment_sizes attribute.
1466     if (!resultSegmentLengths.empty()) {
1467       int64_t size = resultSegmentLengths.size();
1468       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1469           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1470           resultSegmentLengths.size(), resultSegmentLengths.data());
1471       (*attributes)["result_segment_sizes"] =
1472           PyAttribute(context, segmentLengthAttr);
1473     }
1474 
1475     // Add operand_segment_sizes attribute.
1476     if (!operandSegmentLengths.empty()) {
1477       int64_t size = operandSegmentLengths.size();
1478       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1479           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1480           operandSegmentLengths.size(), operandSegmentLengths.data());
1481       (*attributes)["operand_segment_sizes"] =
1482           PyAttribute(context, segmentLengthAttr);
1483     }
1484   }
1485 
1486   // Delegate to create.
1487   return PyOperation::create(name,
1488                              /*results=*/std::move(resultTypes),
1489                              /*operands=*/std::move(operands),
1490                              /*attributes=*/std::move(attributes),
1491                              /*successors=*/std::move(successors),
1492                              /*regions=*/*regions, location, maybeIp);
1493 }
1494 
/// Wraps an existing operation. `operationObject` may be any Python object
/// whose C++ type derives from PyOperationBase (an Operation or an OpView).
PyOpView::PyOpView(const py::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    // NOTE(review): the second initializer reads the `operation` member, so
    // this assumes `operation` is declared before `operationObject` in the
    // class definition (IRModule.h) — TODO confirm.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1500 
1501 py::object PyOpView::createRawSubclass(const py::object &userClass) {
1502   // This is... a little gross. The typical pattern is to have a pure python
1503   // class that extends OpView like:
1504   //   class AddFOp(_cext.ir.OpView):
1505   //     def __init__(self, loc, lhs, rhs):
1506   //       operation = loc.context.create_operation(
1507   //           "addf", lhs, rhs, results=[lhs.type])
1508   //       super().__init__(operation)
1509   //
1510   // I.e. The goal of the user facing type is to provide a nice constructor
1511   // that has complete freedom for the op under construction. This is at odds
1512   // with our other desire to sometimes create this object by just passing an
1513   // operation (to initialize the base class). We could do *arg and **kwargs
1514   // munging to try to make it work, but instead, we synthesize a new class
1515   // on the fly which extends this user class (AddFOp in this example) and
1516   // *give it* the base class's __init__ method, thus bypassing the
1517   // intermediate subclass's __init__ method entirely. While slightly,
1518   // underhanded, this is safe/legal because the type hierarchy has not changed
1519   // (we just added a new leaf) and we aren't mucking around with __new__.
1520   // Typically, this new class will be stored on the original as "_Raw" and will
1521   // be used for casts and other things that need a variant of the class that
1522   // is initialized purely from an operation.
1523   py::object parentMetaclass =
1524       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1525   py::dict attributes;
1526   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1527   // now.
1528   //   auto opViewType = py::type::of<PyOpView>();
1529   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1530   attributes["__init__"] = opViewType.attr("__init__");
1531   py::str origName = userClass.attr("__name__");
1532   py::str newName = py::str("_") + origName;
1533   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1534 }
1535 
1536 //------------------------------------------------------------------------------
1537 // PyInsertionPoint.
1538 //------------------------------------------------------------------------------
1539 
1540 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1541 
1542 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1543     : refOperation(beforeOperationBase.getOperation().getRef()),
1544       block((*refOperation)->getBlock()) {}
1545 
1546 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1547   PyOperation &operation = operationBase.getOperation();
1548   if (operation.isAttached())
1549     throw SetPyError(PyExc_ValueError,
1550                      "Attempt to insert operation that is already attached");
1551   block.getParentOperation()->checkValid();
1552   MlirOperation beforeOp = {nullptr};
1553   if (refOperation) {
1554     // Insert before operation.
1555     (*refOperation)->checkValid();
1556     beforeOp = (*refOperation)->get();
1557   } else {
1558     // Insert at end (before null) is only valid if the block does not
1559     // already end in a known terminator (violating this will cause assertion
1560     // failures later).
1561     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1562       throw py::index_error("Cannot insert operation at the end of a block "
1563                             "that already has a terminator. Did you mean to "
1564                             "use 'InsertionPoint.at_block_terminator(block)' "
1565                             "versus 'InsertionPoint(block)'?");
1566     }
1567   }
1568   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1569   operation.setAttached();
1570 }
1571 
1572 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1573   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1574   if (mlirOperationIsNull(firstOp)) {
1575     // Just insert at end.
1576     return PyInsertionPoint(block);
1577   }
1578 
1579   // Insert before first op.
1580   PyOperationRef firstOpRef = PyOperation::forOperation(
1581       block.getParentOperation()->getContext(), firstOp);
1582   return PyInsertionPoint{block, std::move(firstOpRef)};
1583 }
1584 
1585 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1586   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1587   if (mlirOperationIsNull(terminator))
1588     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1589   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1590       block.getParentOperation()->getContext(), terminator);
1591   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1592 }
1593 
1594 py::object PyInsertionPoint::contextEnter() {
1595   return PyThreadContextEntry::pushInsertionPoint(*this);
1596 }
1597 
1598 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1599                                    const pybind11::object &excVal,
1600                                    const pybind11::object &excTb) {
1601   PyThreadContextEntry::popInsertionPoint(*this);
1602 }
1603 
1604 //------------------------------------------------------------------------------
1605 // PyAttribute.
1606 //------------------------------------------------------------------------------
1607 
1608 bool PyAttribute::operator==(const PyAttribute &other) {
1609   return mlirAttributeEqual(attr, other.attr);
1610 }
1611 
1612 py::object PyAttribute::getCapsule() {
1613   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1614 }
1615 
1616 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1617   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1618   if (mlirAttributeIsNull(rawAttr))
1619     throw py::error_already_set();
1620   return PyAttribute(
1621       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1622 }
1623 
1624 //------------------------------------------------------------------------------
1625 // PyNamedAttribute.
1626 //------------------------------------------------------------------------------
1627 
1628 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1629     : ownedName(new std::string(std::move(ownedName))) {
1630   namedAttr = mlirNamedAttributeGet(
1631       mlirIdentifierGet(mlirAttributeGetContext(attr),
1632                         toMlirStringRef(*this->ownedName)),
1633       attr);
1634 }
1635 
1636 //------------------------------------------------------------------------------
1637 // PyType.
1638 //------------------------------------------------------------------------------
1639 
1640 bool PyType::operator==(const PyType &other) {
1641   return mlirTypeEqual(type, other.type);
1642 }
1643 
1644 py::object PyType::getCapsule() {
1645   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1646 }
1647 
1648 PyType PyType::createFromCapsule(py::object capsule) {
1649   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1650   if (mlirTypeIsNull(rawType))
1651     throw py::error_already_set();
1652   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1653                 rawType);
1654 }
1655 
1656 //------------------------------------------------------------------------------
// PyValue and subclasses.
1658 //------------------------------------------------------------------------------
1659 
1660 pybind11::object PyValue::getCapsule() {
1661   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1662 }
1663 
1664 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1665   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1666   if (mlirValueIsNull(value))
1667     throw py::error_already_set();
1668   MlirOperation owner;
1669   if (mlirValueIsAOpResult(value))
1670     owner = mlirOpResultGetOwner(value);
1671   if (mlirValueIsABlockArgument(value))
1672     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1673   if (mlirOperationIsNull(owner))
1674     throw py::error_already_set();
1675   MlirContext ctx = mlirOperationGetContext(owner);
1676   PyOperationRef ownerRef =
1677       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1678   return PyValue(ownerRef, value);
1679 }
1680 
1681 //------------------------------------------------------------------------------
1682 // PySymbolTable.
1683 //------------------------------------------------------------------------------
1684 
1685 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1686     : operation(operation.getOperation().getRef()) {
1687   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1688   if (mlirSymbolTableIsNull(symbolTable)) {
1689     throw py::cast_error("Operation is not a Symbol Table.");
1690   }
1691 }
1692 
1693 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1694   operation->checkValid();
1695   MlirOperation symbol = mlirSymbolTableLookup(
1696       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1697   if (mlirOperationIsNull(symbol))
1698     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1699 
1700   return PyOperation::forOperation(operation->getContext(), symbol,
1701                                    operation.getObject())
1702       ->createOpView();
1703 }
1704 
1705 void PySymbolTable::erase(PyOperationBase &symbol) {
1706   operation->checkValid();
1707   symbol.getOperation().checkValid();
1708   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1709   // The operation is also erased, so we must invalidate it. There may be Python
1710   // references to this operation so we don't want to delete it from the list of
1711   // live operations here.
1712   symbol.getOperation().valid = false;
1713 }
1714 
1715 void PySymbolTable::dunderDel(const std::string &name) {
1716   py::object operation = dunderGetItem(name);
1717   erase(py::cast<PyOperationBase &>(operation));
1718 }
1719 
1720 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1721   operation->checkValid();
1722   symbol.getOperation().checkValid();
1723   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1724       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1725   if (mlirAttributeIsNull(symbolAttr))
1726     throw py::value_error("Expected operation to have a symbol name.");
1727   return PyAttribute(
1728       symbol.getOperation().getContext(),
1729       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1730 }
1731 
1732 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1733   // Op must already be a symbol.
1734   PyOperation &operation = symbol.getOperation();
1735   operation.checkValid();
1736   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1737   MlirAttribute existingNameAttr =
1738       mlirOperationGetAttributeByName(operation.get(), attrName);
1739   if (mlirAttributeIsNull(existingNameAttr))
1740     throw py::value_error("Expected operation to have a symbol name.");
1741   return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1742 }
1743 
1744 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1745                                   const std::string &name) {
1746   // Op must already be a symbol.
1747   PyOperation &operation = symbol.getOperation();
1748   operation.checkValid();
1749   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1750   MlirAttribute existingNameAttr =
1751       mlirOperationGetAttributeByName(operation.get(), attrName);
1752   if (mlirAttributeIsNull(existingNameAttr))
1753     throw py::value_error("Expected operation to have a symbol name.");
1754   MlirAttribute newNameAttr =
1755       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1756   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1757 }
1758 
1759 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1760   PyOperation &operation = symbol.getOperation();
1761   operation.checkValid();
1762   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1763   MlirAttribute existingVisAttr =
1764       mlirOperationGetAttributeByName(operation.get(), attrName);
1765   if (mlirAttributeIsNull(existingVisAttr))
1766     throw py::value_error("Expected operation to have a symbol visibility.");
1767   return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1768 }
1769 
1770 void PySymbolTable::setVisibility(PyOperationBase &symbol,
1771                                   const std::string &visibility) {
1772   if (visibility != "public" && visibility != "private" &&
1773       visibility != "nested")
1774     throw py::value_error(
1775         "Expected visibility to be 'public', 'private' or 'nested'");
1776   PyOperation &operation = symbol.getOperation();
1777   operation.checkValid();
1778   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1779   MlirAttribute existingVisAttr =
1780       mlirOperationGetAttributeByName(operation.get(), attrName);
1781   if (mlirAttributeIsNull(existingVisAttr))
1782     throw py::value_error("Expected operation to have a symbol visibility.");
1783   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1784                                                toMlirStringRef(visibility));
1785   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1786 }
1787 
1788 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1789                                          const std::string &newSymbol,
1790                                          PyOperationBase &from) {
1791   PyOperation &fromOperation = from.getOperation();
1792   fromOperation.checkValid();
1793   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1794           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1795           from.getOperation())))
1796 
1797     throw py::value_error("Symbol rename failed");
1798 }
1799 
1800 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1801                                      bool allSymUsesVisible,
1802                                      py::object callback) {
1803   PyOperation &fromOperation = from.getOperation();
1804   fromOperation.checkValid();
1805   struct UserData {
1806     PyMlirContextRef context;
1807     py::object callback;
1808     bool gotException;
1809     std::string exceptionWhat;
1810     py::object exceptionType;
1811   };
1812   UserData userData{
1813       fromOperation.getContext(), std::move(callback), false, {}, {}};
1814   mlirSymbolTableWalkSymbolTables(
1815       fromOperation.get(), allSymUsesVisible,
1816       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1817         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1818         auto pyFoundOp =
1819             PyOperation::forOperation(calleeUserData->context, foundOp);
1820         if (calleeUserData->gotException)
1821           return;
1822         try {
1823           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1824         } catch (py::error_already_set &e) {
1825           calleeUserData->gotException = true;
1826           calleeUserData->exceptionWhat = e.what();
1827           calleeUserData->exceptionType = e.type();
1828         }
1829       },
1830       static_cast<void *>(&userData));
1831   if (userData.gotException) {
1832     std::string message("Exception raised in callback: ");
1833     message.append(userData.exceptionWhat);
1834     throw std::runtime_error(message);
1835   }
1836 }
1837 
1838 namespace {
1839 /// CRTP base class for Python MLIR values that subclass Value and should be
1840 /// castable from it. The value hierarchy is one level deep and is not supposed
1841 /// to accommodate other levels unless core MLIR changes.
1842 template <typename DerivedTy>
1843 class PyConcreteValue : public PyValue {
1844 public:
1845   // Derived classes must define statics for:
1846   //   IsAFunctionTy isaFunction
1847   //   const char *pyClassName
1848   // and redefine bindDerived.
1849   using ClassTy = py::class_<DerivedTy, PyValue>;
1850   using IsAFunctionTy = bool (*)(MlirValue);
1851 
1852   PyConcreteValue() = default;
1853   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1854       : PyValue(operationRef, value) {}
1855   PyConcreteValue(PyValue &orig)
1856       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1857 
1858   /// Attempts to cast the original value to the derived type and throws on
1859   /// type mismatches.
1860   static MlirValue castFrom(PyValue &orig) {
1861     if (!DerivedTy::isaFunction(orig.get())) {
1862       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1863       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1864                                              DerivedTy::pyClassName +
1865                                              " (from " + origRepr + ")");
1866     }
1867     return orig.get();
1868   }
1869 
1870   /// Binds the Python module objects to functions of this class.
1871   static void bind(py::module &m) {
1872     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1873     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1874     cls.def_static(
1875         "isinstance",
1876         [](PyValue &otherValue) -> bool {
1877           return DerivedTy::isaFunction(otherValue);
1878         },
1879         py::arg("other_value"));
1880     DerivedTy::bindDerived(cls);
1881   }
1882 
1883   /// Implemented by derived classes to add methods to the Python subclass.
1884   static void bindDerived(ClassTy &m) {}
1885 };
1886 
1887 /// Python wrapper for MlirBlockArgument.
1888 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1889 public:
1890   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1891   static constexpr const char *pyClassName = "BlockArgument";
1892   using PyConcreteValue::PyConcreteValue;
1893 
1894   static void bindDerived(ClassTy &c) {
1895     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1896       return PyBlock(self.getParentOperation(),
1897                      mlirBlockArgumentGetOwner(self.get()));
1898     });
1899     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1900       return mlirBlockArgumentGetArgNumber(self.get());
1901     });
1902     c.def(
1903         "set_type",
1904         [](PyBlockArgument &self, PyType type) {
1905           return mlirBlockArgumentSetType(self.get(), type);
1906         },
1907         py::arg("type"));
1908   }
1909 };
1910 
1911 /// Python wrapper for MlirOpResult.
1912 class PyOpResult : public PyConcreteValue<PyOpResult> {
1913 public:
1914   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1915   static constexpr const char *pyClassName = "OpResult";
1916   using PyConcreteValue::PyConcreteValue;
1917 
1918   static void bindDerived(ClassTy &c) {
1919     c.def_property_readonly("owner", [](PyOpResult &self) {
1920       assert(
1921           mlirOperationEqual(self.getParentOperation()->get(),
1922                              mlirOpResultGetOwner(self.get())) &&
1923           "expected the owner of the value in Python to match that in the IR");
1924       return self.getParentOperation().getObject();
1925     });
1926     c.def_property_readonly("result_number", [](PyOpResult &self) {
1927       return mlirOpResultGetResultNumber(self.get());
1928     });
1929   }
1930 };
1931 
1932 /// Returns the list of types of the values held by container.
1933 template <typename Container>
1934 static std::vector<PyType> getValueTypes(Container &container,
1935                                          PyMlirContextRef &context) {
1936   std::vector<PyType> result;
1937   result.reserve(container.getNumElements());
1938   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1939     result.push_back(
1940         PyType(context, mlirValueGetType(container.getElement(i).get())));
1941   }
1942   return result;
1943 }
1944 
1945 /// A list of block arguments. Internally, these are stored as consecutive
1946 /// elements, random access is cheap. The argument list is associated with the
1947 /// operation that contains the block (detached blocks are not allowed in
1948 /// Python bindings) and extends its lifetime.
1949 class PyBlockArgumentList
1950     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1951 public:
1952   static constexpr const char *pyClassName = "BlockArgumentList";
1953 
1954   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1955                       intptr_t startIndex = 0, intptr_t length = -1,
1956                       intptr_t step = 1)
1957       : Sliceable(startIndex,
1958                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1959                   step),
1960         operation(std::move(operation)), block(block) {}
1961 
1962   /// Returns the number of arguments in the list.
1963   intptr_t getNumElements() {
1964     operation->checkValid();
1965     return mlirBlockGetNumArguments(block);
1966   }
1967 
1968   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1969   PyBlockArgument getElement(intptr_t pos) {
1970     MlirValue argument = mlirBlockGetArgument(block, pos);
1971     return PyBlockArgument(operation, argument);
1972   }
1973 
1974   /// Returns a sublist of this list.
1975   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1976                             intptr_t step) {
1977     return PyBlockArgumentList(operation, block, startIndex, length, step);
1978   }
1979 
1980   static void bindDerived(ClassTy &c) {
1981     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1982       return getValueTypes(self, self.operation->getContext());
1983     });
1984   }
1985 
1986 private:
1987   PyOperationRef operation;
1988   MlirBlock block;
1989 };
1990 
1991 /// A list of operation operands. Internally, these are stored as consecutive
1992 /// elements, random access is cheap. The result list is associated with the
1993 /// operation whose results these are, and extends the lifetime of this
1994 /// operation.
1995 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1996 public:
1997   static constexpr const char *pyClassName = "OpOperandList";
1998 
1999   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2000                   intptr_t length = -1, intptr_t step = 1)
2001       : Sliceable(startIndex,
2002                   length == -1 ? mlirOperationGetNumOperands(operation->get())
2003                                : length,
2004                   step),
2005         operation(operation) {}
2006 
2007   intptr_t getNumElements() {
2008     operation->checkValid();
2009     return mlirOperationGetNumOperands(operation->get());
2010   }
2011 
2012   PyValue getElement(intptr_t pos) {
2013     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2014     MlirOperation owner;
2015     if (mlirValueIsAOpResult(operand))
2016       owner = mlirOpResultGetOwner(operand);
2017     else if (mlirValueIsABlockArgument(operand))
2018       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2019     else
2020       assert(false && "Value must be an block arg or op result.");
2021     PyOperationRef pyOwner =
2022         PyOperation::forOperation(operation->getContext(), owner);
2023     return PyValue(pyOwner, operand);
2024   }
2025 
2026   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2027     return PyOpOperandList(operation, startIndex, length, step);
2028   }
2029 
2030   void dunderSetItem(intptr_t index, PyValue value) {
2031     index = wrapIndex(index);
2032     mlirOperationSetOperand(operation->get(), index, value.get());
2033   }
2034 
2035   static void bindDerived(ClassTy &c) {
2036     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2037   }
2038 
2039 private:
2040   PyOperationRef operation;
2041 };
2042 
2043 /// A list of operation results. Internally, these are stored as consecutive
2044 /// elements, random access is cheap. The result list is associated with the
2045 /// operation whose results these are, and extends the lifetime of this
2046 /// operation.
2047 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2048 public:
2049   static constexpr const char *pyClassName = "OpResultList";
2050 
2051   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2052                  intptr_t length = -1, intptr_t step = 1)
2053       : Sliceable(startIndex,
2054                   length == -1 ? mlirOperationGetNumResults(operation->get())
2055                                : length,
2056                   step),
2057         operation(operation) {}
2058 
2059   intptr_t getNumElements() {
2060     operation->checkValid();
2061     return mlirOperationGetNumResults(operation->get());
2062   }
2063 
2064   PyOpResult getElement(intptr_t index) {
2065     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2066     return PyOpResult(value);
2067   }
2068 
2069   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2070     return PyOpResultList(operation, startIndex, length, step);
2071   }
2072 
2073   static void bindDerived(ClassTy &c) {
2074     c.def_property_readonly("types", [](PyOpResultList &self) {
2075       return getValueTypes(self, self.operation->getContext());
2076     });
2077   }
2078 
2079 private:
2080   PyOperationRef operation;
2081 };
2082 
2083 /// A list of operation attributes. Can be indexed by name, producing
2084 /// attributes, or by index, producing named attributes.
2085 class PyOpAttributeMap {
2086 public:
2087   PyOpAttributeMap(PyOperationRef operation)
2088       : operation(std::move(operation)) {}
2089 
2090   PyAttribute dunderGetItemNamed(const std::string &name) {
2091     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2092                                                          toMlirStringRef(name));
2093     if (mlirAttributeIsNull(attr)) {
2094       throw SetPyError(PyExc_KeyError,
2095                        "attempt to access a non-existent attribute");
2096     }
2097     return PyAttribute(operation->getContext(), attr);
2098   }
2099 
2100   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2101     if (index < 0 || index >= dunderLen()) {
2102       throw SetPyError(PyExc_IndexError,
2103                        "attempt to access out of bounds attribute");
2104     }
2105     MlirNamedAttribute namedAttr =
2106         mlirOperationGetAttribute(operation->get(), index);
2107     return PyNamedAttribute(
2108         namedAttr.attribute,
2109         std::string(mlirIdentifierStr(namedAttr.name).data,
2110                     mlirIdentifierStr(namedAttr.name).length));
2111   }
2112 
2113   void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2114     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2115                                     attr);
2116   }
2117 
2118   void dunderDelItem(const std::string &name) {
2119     int removed = mlirOperationRemoveAttributeByName(operation->get(),
2120                                                      toMlirStringRef(name));
2121     if (!removed)
2122       throw SetPyError(PyExc_KeyError,
2123                        "attempt to delete a non-existent attribute");
2124   }
2125 
2126   intptr_t dunderLen() {
2127     return mlirOperationGetNumAttributes(operation->get());
2128   }
2129 
2130   bool dunderContains(const std::string &name) {
2131     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2132         operation->get(), toMlirStringRef(name)));
2133   }
2134 
2135   static void bind(py::module &m) {
2136     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2137         .def("__contains__", &PyOpAttributeMap::dunderContains)
2138         .def("__len__", &PyOpAttributeMap::dunderLen)
2139         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2140         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2141         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2142         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2143   }
2144 
2145 private:
2146   PyOperationRef operation;
2147 };
2148 
2149 } // namespace
2150 
2151 //------------------------------------------------------------------------------
2152 // Populates the core exports of the 'ir' submodule.
2153 //------------------------------------------------------------------------------
2154 
2155 void mlir::python::populateIRCore(py::module &m) {
2156   //----------------------------------------------------------------------------
2157   // Enums.
2158   //----------------------------------------------------------------------------
2159   py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2160       .value("ERROR", MlirDiagnosticError)
2161       .value("WARNING", MlirDiagnosticWarning)
2162       .value("NOTE", MlirDiagnosticNote)
2163       .value("REMARK", MlirDiagnosticRemark);
2164 
2165   //----------------------------------------------------------------------------
2166   // Mapping of Diagnostics.
2167   //----------------------------------------------------------------------------
2168   py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2169       .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2170       .def_property_readonly("location", &PyDiagnostic::getLocation)
2171       .def_property_readonly("message", &PyDiagnostic::getMessage)
2172       .def_property_readonly("notes", &PyDiagnostic::getNotes)
2173       .def("__str__", [](PyDiagnostic &self) -> py::str {
2174         if (!self.isValid())
2175           return "<Invalid Diagnostic>";
2176         return self.getMessage();
2177       });
2178 
2179   py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2180       .def("detach", &PyDiagnosticHandler::detach)
2181       .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2182       .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2183       .def("__enter__", &PyDiagnosticHandler::contextEnter)
2184       .def("__exit__", &PyDiagnosticHandler::contextExit);
2185 
2186   //----------------------------------------------------------------------------
2187   // Mapping of MlirContext.
2188   //----------------------------------------------------------------------------
2189   py::class_<PyMlirContext>(m, "Context", py::module_local())
2190       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2191       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2192       .def("_get_context_again",
2193            [](PyMlirContext &self) {
2194              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2195              return ref.releaseObject();
2196            })
2197       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2198       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2199       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2200                              &PyMlirContext::getCapsule)
2201       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2202       .def("__enter__", &PyMlirContext::contextEnter)
2203       .def("__exit__", &PyMlirContext::contextExit)
2204       .def_property_readonly_static(
2205           "current",
2206           [](py::object & /*class*/) {
2207             auto *context = PyThreadContextEntry::getDefaultContext();
2208             if (!context)
2209               throw SetPyError(PyExc_ValueError, "No current Context");
2210             return context;
2211           },
2212           "Gets the Context bound to the current thread or raises ValueError")
2213       .def_property_readonly(
2214           "dialects",
2215           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2216           "Gets a container for accessing dialects by name")
2217       .def_property_readonly(
2218           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2219           "Alias for 'dialect'")
2220       .def(
2221           "get_dialect_descriptor",
2222           [=](PyMlirContext &self, std::string &name) {
2223             MlirDialect dialect = mlirContextGetOrLoadDialect(
2224                 self.get(), {name.data(), name.size()});
2225             if (mlirDialectIsNull(dialect)) {
2226               throw SetPyError(PyExc_ValueError,
2227                                Twine("Dialect '") + name + "' not found");
2228             }
2229             return PyDialectDescriptor(self.getRef(), dialect);
2230           },
2231           py::arg("dialect_name"),
2232           "Gets or loads a dialect by name, returning its descriptor object")
2233       .def_property(
2234           "allow_unregistered_dialects",
2235           [](PyMlirContext &self) -> bool {
2236             return mlirContextGetAllowUnregisteredDialects(self.get());
2237           },
2238           [](PyMlirContext &self, bool value) {
2239             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2240           })
2241       .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2242            py::arg("callback"),
2243            "Attaches a diagnostic handler that will receive callbacks")
2244       .def(
2245           "enable_multithreading",
2246           [](PyMlirContext &self, bool enable) {
2247             mlirContextEnableMultithreading(self.get(), enable);
2248           },
2249           py::arg("enable"))
2250       .def(
2251           "is_registered_operation",
2252           [](PyMlirContext &self, std::string &name) {
2253             return mlirContextIsRegisteredOperation(
2254                 self.get(), MlirStringRef{name.data(), name.size()});
2255           },
2256           py::arg("operation_name"));
2257 
2258   //----------------------------------------------------------------------------
2259   // Mapping of PyDialectDescriptor
2260   //----------------------------------------------------------------------------
2261   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2262       .def_property_readonly("namespace",
2263                              [](PyDialectDescriptor &self) {
2264                                MlirStringRef ns =
2265                                    mlirDialectGetNamespace(self.get());
2266                                return py::str(ns.data, ns.length);
2267                              })
2268       .def("__repr__", [](PyDialectDescriptor &self) {
2269         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2270         std::string repr("<DialectDescriptor ");
2271         repr.append(ns.data, ns.length);
2272         repr.append(">");
2273         return repr;
2274       });
2275 
2276   //----------------------------------------------------------------------------
2277   // Mapping of PyDialects
2278   //----------------------------------------------------------------------------
2279   py::class_<PyDialects>(m, "Dialects", py::module_local())
2280       .def("__getitem__",
2281            [=](PyDialects &self, std::string keyName) {
2282              MlirDialect dialect =
2283                  self.getDialectForKey(keyName, /*attrError=*/false);
2284              py::object descriptor =
2285                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2286              return createCustomDialectWrapper(keyName, std::move(descriptor));
2287            })
2288       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2289         MlirDialect dialect =
2290             self.getDialectForKey(attrName, /*attrError=*/true);
2291         py::object descriptor =
2292             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2293         return createCustomDialectWrapper(attrName, std::move(descriptor));
2294       });
2295 
2296   //----------------------------------------------------------------------------
2297   // Mapping of PyDialect
2298   //----------------------------------------------------------------------------
2299   py::class_<PyDialect>(m, "Dialect", py::module_local())
2300       .def(py::init<py::object>(), py::arg("descriptor"))
2301       .def_property_readonly(
2302           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2303       .def("__repr__", [](py::object self) {
2304         auto clazz = self.attr("__class__");
2305         return py::str("<Dialect ") +
2306                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2307                clazz.attr("__module__") + py::str(".") +
2308                clazz.attr("__name__") + py::str(")>");
2309       });
2310 
2311   //----------------------------------------------------------------------------
2312   // Mapping of Location
2313   //----------------------------------------------------------------------------
2314   py::class_<PyLocation>(m, "Location", py::module_local())
2315       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2316       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2317       .def("__enter__", &PyLocation::contextEnter)
2318       .def("__exit__", &PyLocation::contextExit)
2319       .def("__eq__",
2320            [](PyLocation &self, PyLocation &other) -> bool {
2321              return mlirLocationEqual(self, other);
2322            })
2323       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2324       .def_property_readonly_static(
2325           "current",
2326           [](py::object & /*class*/) {
2327             auto *loc = PyThreadContextEntry::getDefaultLocation();
2328             if (!loc)
2329               throw SetPyError(PyExc_ValueError, "No current Location");
2330             return loc;
2331           },
2332           "Gets the Location bound to the current thread or raises ValueError")
2333       .def_static(
2334           "unknown",
2335           [](DefaultingPyMlirContext context) {
2336             return PyLocation(context->getRef(),
2337                               mlirLocationUnknownGet(context->get()));
2338           },
2339           py::arg("context") = py::none(),
2340           "Gets a Location representing an unknown location")
2341       .def_static(
2342           "callsite",
2343           [](PyLocation callee, const std::vector<PyLocation> &frames,
2344              DefaultingPyMlirContext context) {
2345             if (frames.empty())
2346               throw py::value_error("No caller frames provided");
2347             MlirLocation caller = frames.back().get();
2348             for (const PyLocation &frame :
2349                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2350               caller = mlirLocationCallSiteGet(frame.get(), caller);
2351             return PyLocation(context->getRef(),
2352                               mlirLocationCallSiteGet(callee.get(), caller));
2353           },
2354           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2355           kContextGetCallSiteLocationDocstring)
2356       .def_static(
2357           "file",
2358           [](std::string filename, int line, int col,
2359              DefaultingPyMlirContext context) {
2360             return PyLocation(
2361                 context->getRef(),
2362                 mlirLocationFileLineColGet(
2363                     context->get(), toMlirStringRef(filename), line, col));
2364           },
2365           py::arg("filename"), py::arg("line"), py::arg("col"),
2366           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2367       .def_static(
2368           "fused",
2369           [](const std::vector<PyLocation> &pyLocations,
2370              llvm::Optional<PyAttribute> metadata,
2371              DefaultingPyMlirContext context) {
2372             llvm::SmallVector<MlirLocation, 4> locations;
2373             locations.reserve(pyLocations.size());
2374             for (auto &pyLocation : pyLocations)
2375               locations.push_back(pyLocation.get());
2376             MlirLocation location = mlirLocationFusedGet(
2377                 context->get(), locations.size(), locations.data(),
2378                 metadata ? metadata->get() : MlirAttribute{0});
2379             return PyLocation(context->getRef(), location);
2380           },
2381           py::arg("locations"), py::arg("metadata") = py::none(),
2382           py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2383       .def_static(
2384           "name",
2385           [](std::string name, llvm::Optional<PyLocation> childLoc,
2386              DefaultingPyMlirContext context) {
2387             return PyLocation(
2388                 context->getRef(),
2389                 mlirLocationNameGet(
2390                     context->get(), toMlirStringRef(name),
2391                     childLoc ? childLoc->get()
2392                              : mlirLocationUnknownGet(context->get())));
2393           },
2394           py::arg("name"), py::arg("childLoc") = py::none(),
2395           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2396       .def_property_readonly(
2397           "context",
2398           [](PyLocation &self) { return self.getContext().getObject(); },
2399           "Context that owns the Location")
2400       .def(
2401           "emit_error",
2402           [](PyLocation &self, std::string message) {
2403             mlirEmitError(self, message.c_str());
2404           },
2405           py::arg("message"), "Emits an error at this location")
2406       .def("__repr__", [](PyLocation &self) {
2407         PyPrintAccumulator printAccum;
2408         mlirLocationPrint(self, printAccum.getCallback(),
2409                           printAccum.getUserData());
2410         return printAccum.join();
2411       });
2412 
2413   //----------------------------------------------------------------------------
2414   // Mapping of Module
2415   //----------------------------------------------------------------------------
2416   py::class_<PyModule>(m, "Module", py::module_local())
2417       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2418       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2419       .def_static(
2420           "parse",
2421           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2422             MlirModule module = mlirModuleCreateParse(
2423                 context->get(), toMlirStringRef(moduleAsm));
2424             // TODO: Rework error reporting once diagnostic engine is exposed
2425             // in C API.
2426             if (mlirModuleIsNull(module)) {
2427               throw SetPyError(
2428                   PyExc_ValueError,
2429                   "Unable to parse module assembly (see diagnostics)");
2430             }
2431             return PyModule::forModule(module).releaseObject();
2432           },
2433           py::arg("asm"), py::arg("context") = py::none(),
2434           kModuleParseDocstring)
2435       .def_static(
2436           "create",
2437           [](DefaultingPyLocation loc) {
2438             MlirModule module = mlirModuleCreateEmpty(loc);
2439             return PyModule::forModule(module).releaseObject();
2440           },
2441           py::arg("loc") = py::none(), "Creates an empty module")
2442       .def_property_readonly(
2443           "context",
2444           [](PyModule &self) { return self.getContext().getObject(); },
2445           "Context that created the Module")
2446       .def_property_readonly(
2447           "operation",
2448           [](PyModule &self) {
2449             return PyOperation::forOperation(self.getContext(),
2450                                              mlirModuleGetOperation(self.get()),
2451                                              self.getRef().releaseObject())
2452                 .releaseObject();
2453           },
2454           "Accesses the module as an operation")
2455       .def_property_readonly(
2456           "body",
2457           [](PyModule &self) {
2458             PyOperationRef moduleOp = PyOperation::forOperation(
2459                 self.getContext(), mlirModuleGetOperation(self.get()),
2460                 self.getRef().releaseObject());
2461             PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2462             return returnBlock;
2463           },
2464           "Return the block for this module")
2465       .def(
2466           "dump",
2467           [](PyModule &self) {
2468             mlirOperationDump(mlirModuleGetOperation(self.get()));
2469           },
2470           kDumpDocstring)
2471       .def(
2472           "__str__",
2473           [](py::object self) {
2474             // Defer to the operation's __str__.
2475             return self.attr("operation").attr("__str__")();
2476           },
2477           kOperationStrDunderDocstring);
2478 
2479   //----------------------------------------------------------------------------
2480   // Mapping of Operation.
2481   //----------------------------------------------------------------------------
2482   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2483       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2484                              [](PyOperationBase &self) {
2485                                return self.getOperation().getCapsule();
2486                              })
2487       .def("__eq__",
2488            [](PyOperationBase &self, PyOperationBase &other) {
2489              return &self.getOperation() == &other.getOperation();
2490            })
2491       .def("__eq__",
2492            [](PyOperationBase &self, py::object other) { return false; })
2493       .def("__hash__",
2494            [](PyOperationBase &self) {
2495              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2496            })
2497       .def_property_readonly("attributes",
2498                              [](PyOperationBase &self) {
2499                                return PyOpAttributeMap(
2500                                    self.getOperation().getRef());
2501                              })
2502       .def_property_readonly("operands",
2503                              [](PyOperationBase &self) {
2504                                return PyOpOperandList(
2505                                    self.getOperation().getRef());
2506                              })
2507       .def_property_readonly("regions",
2508                              [](PyOperationBase &self) {
2509                                return PyRegionList(
2510                                    self.getOperation().getRef());
2511                              })
2512       .def_property_readonly(
2513           "results",
2514           [](PyOperationBase &self) {
2515             return PyOpResultList(self.getOperation().getRef());
2516           },
2517           "Returns the list of Operation results.")
2518       .def_property_readonly(
2519           "result",
2520           [](PyOperationBase &self) {
2521             auto &operation = self.getOperation();
2522             auto numResults = mlirOperationGetNumResults(operation);
2523             if (numResults != 1) {
2524               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2525               throw SetPyError(
2526                   PyExc_ValueError,
2527                   Twine("Cannot call .result on operation ") +
2528                       StringRef(name.data, name.length) + " which has " +
2529                       Twine(numResults) +
2530                       " results (it is only valid for operations with a "
2531                       "single result)");
2532             }
2533             return PyOpResult(operation.getRef(),
2534                               mlirOperationGetResult(operation, 0));
2535           },
2536           "Shortcut to get an op result if it has only one (throws an error "
2537           "otherwise).")
2538       .def_property_readonly(
2539           "location",
2540           [](PyOperationBase &self) {
2541             PyOperation &operation = self.getOperation();
2542             return PyLocation(operation.getContext(),
2543                               mlirOperationGetLocation(operation.get()));
2544           },
2545           "Returns the source location the operation was defined or derived "
2546           "from.")
2547       .def(
2548           "__str__",
2549           [](PyOperationBase &self) {
2550             return self.getAsm(/*binary=*/false,
2551                                /*largeElementsLimit=*/llvm::None,
2552                                /*enableDebugInfo=*/false,
2553                                /*prettyDebugInfo=*/false,
2554                                /*printGenericOpForm=*/false,
2555                                /*useLocalScope=*/false,
2556                                /*assumeVerified=*/false);
2557           },
2558           "Returns the assembly form of the operation.")
2559       .def("print", &PyOperationBase::print,
2560            // Careful: Lots of arguments must match up with print method.
2561            py::arg("file") = py::none(), py::arg("binary") = false,
2562            py::arg("large_elements_limit") = py::none(),
2563            py::arg("enable_debug_info") = false,
2564            py::arg("pretty_debug_info") = false,
2565            py::arg("print_generic_op_form") = false,
2566            py::arg("use_local_scope") = false,
2567            py::arg("assume_verified") = false, kOperationPrintDocstring)
2568       .def("get_asm", &PyOperationBase::getAsm,
2569            // Careful: Lots of arguments must match up with get_asm method.
2570            py::arg("binary") = false,
2571            py::arg("large_elements_limit") = py::none(),
2572            py::arg("enable_debug_info") = false,
2573            py::arg("pretty_debug_info") = false,
2574            py::arg("print_generic_op_form") = false,
2575            py::arg("use_local_scope") = false,
2576            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
2577       .def(
2578           "verify",
2579           [](PyOperationBase &self) {
2580             return mlirOperationVerify(self.getOperation());
2581           },
2582           "Verify the operation and return true if it passes, false if it "
2583           "fails.")
2584       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2585            "Puts self immediately after the other operation in its parent "
2586            "block.")
2587       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2588            "Puts self immediately before the other operation in its parent "
2589            "block.")
2590       .def(
2591           "detach_from_parent",
2592           [](PyOperationBase &self) {
2593             PyOperation &operation = self.getOperation();
2594             operation.checkValid();
2595             if (!operation.isAttached())
2596               throw py::value_error("Detached operation has no parent.");
2597 
2598             operation.detachFromParent();
2599             return operation.createOpView();
2600           },
2601           "Detaches the operation from its parent block.");
2602 
2603   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2604       .def_static("create", &PyOperation::create, py::arg("name"),
2605                   py::arg("results") = py::none(),
2606                   py::arg("operands") = py::none(),
2607                   py::arg("attributes") = py::none(),
2608                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2609                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2610                   kOperationCreateDocstring)
2611       .def_property_readonly("parent",
2612                              [](PyOperation &self) -> py::object {
2613                                auto parent = self.getParentOperation();
2614                                if (parent)
2615                                  return parent->getObject();
2616                                return py::none();
2617                              })
2618       .def("erase", &PyOperation::erase)
2619       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2620                              &PyOperation::getCapsule)
2621       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2622       .def_property_readonly("name",
2623                              [](PyOperation &self) {
2624                                self.checkValid();
2625                                MlirOperation operation = self.get();
2626                                MlirStringRef name = mlirIdentifierStr(
2627                                    mlirOperationGetName(operation));
2628                                return py::str(name.data, name.length);
2629                              })
2630       .def_property_readonly(
2631           "context",
2632           [](PyOperation &self) {
2633             self.checkValid();
2634             return self.getContext().getObject();
2635           },
2636           "Context that owns the Operation")
2637       .def_property_readonly("opview", &PyOperation::createOpView);
2638 
2639   auto opViewClass =
2640       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2641           .def(py::init<py::object>(), py::arg("operation"))
2642           .def_property_readonly("operation", &PyOpView::getOperationObject)
2643           .def_property_readonly(
2644               "context",
2645               [](PyOpView &self) {
2646                 return self.getOperation().getContext().getObject();
2647               },
2648               "Context that owns the Operation")
2649           .def("__str__", [](PyOpView &self) {
2650             return py::str(self.getOperationObject());
2651           });
2652   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2653   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2654   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2655   opViewClass.attr("build_generic") = classmethod(
2656       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2657       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2658       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2659       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2660       "Builds a specific, generated OpView based on class level attributes.");
2661 
2662   //----------------------------------------------------------------------------
2663   // Mapping of PyRegion.
2664   //----------------------------------------------------------------------------
2665   py::class_<PyRegion>(m, "Region", py::module_local())
2666       .def_property_readonly(
2667           "blocks",
2668           [](PyRegion &self) {
2669             return PyBlockList(self.getParentOperation(), self.get());
2670           },
2671           "Returns a forward-optimized sequence of blocks.")
2672       .def_property_readonly(
2673           "owner",
2674           [](PyRegion &self) {
2675             return self.getParentOperation()->createOpView();
2676           },
2677           "Returns the operation owning this region.")
2678       .def(
2679           "__iter__",
2680           [](PyRegion &self) {
2681             self.checkValid();
2682             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2683             return PyBlockIterator(self.getParentOperation(), firstBlock);
2684           },
2685           "Iterates over blocks in the region.")
2686       .def("__eq__",
2687            [](PyRegion &self, PyRegion &other) {
2688              return self.get().ptr == other.get().ptr;
2689            })
2690       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2691 
2692   //----------------------------------------------------------------------------
2693   // Mapping of PyBlock.
2694   //----------------------------------------------------------------------------
2695   py::class_<PyBlock>(m, "Block", py::module_local())
2696       .def_property_readonly(
2697           "owner",
2698           [](PyBlock &self) {
2699             return self.getParentOperation()->createOpView();
2700           },
2701           "Returns the owning operation of this block.")
2702       .def_property_readonly(
2703           "region",
2704           [](PyBlock &self) {
2705             MlirRegion region = mlirBlockGetParentRegion(self.get());
2706             return PyRegion(self.getParentOperation(), region);
2707           },
2708           "Returns the owning region of this block.")
2709       .def_property_readonly(
2710           "arguments",
2711           [](PyBlock &self) {
2712             return PyBlockArgumentList(self.getParentOperation(), self.get());
2713           },
2714           "Returns a list of block arguments.")
2715       .def_property_readonly(
2716           "operations",
2717           [](PyBlock &self) {
2718             return PyOperationList(self.getParentOperation(), self.get());
2719           },
2720           "Returns a forward-optimized sequence of operations.")
2721       .def_static(
2722           "create_at_start",
2723           [](PyRegion &parent, py::list pyArgTypes) {
2724             parent.checkValid();
2725             llvm::SmallVector<MlirType, 4> argTypes;
2726             llvm::SmallVector<MlirLocation, 4> argLocs;
2727             argTypes.reserve(pyArgTypes.size());
2728             argLocs.reserve(pyArgTypes.size());
2729             for (auto &pyArg : pyArgTypes) {
2730               argTypes.push_back(pyArg.cast<PyType &>());
2731               // TODO: Pass in a proper location here.
2732               argLocs.push_back(
2733                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2734             }
2735 
2736             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2737                                               argLocs.data());
2738             mlirRegionInsertOwnedBlock(parent, 0, block);
2739             return PyBlock(parent.getParentOperation(), block);
2740           },
2741           py::arg("parent"), py::arg("arg_types") = py::list(),
2742           "Creates and returns a new Block at the beginning of the given "
2743           "region (with given argument types).")
2744       .def(
2745           "create_before",
2746           [](PyBlock &self, py::args pyArgTypes) {
2747             self.checkValid();
2748             llvm::SmallVector<MlirType, 4> argTypes;
2749             llvm::SmallVector<MlirLocation, 4> argLocs;
2750             argTypes.reserve(pyArgTypes.size());
2751             argLocs.reserve(pyArgTypes.size());
2752             for (auto &pyArg : pyArgTypes) {
2753               argTypes.push_back(pyArg.cast<PyType &>());
2754               // TODO: Pass in a proper location here.
2755               argLocs.push_back(
2756                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2757             }
2758 
2759             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2760                                               argLocs.data());
2761             MlirRegion region = mlirBlockGetParentRegion(self.get());
2762             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2763             return PyBlock(self.getParentOperation(), block);
2764           },
2765           "Creates and returns a new Block before this block "
2766           "(with given argument types).")
2767       .def(
2768           "create_after",
2769           [](PyBlock &self, py::args pyArgTypes) {
2770             self.checkValid();
2771             llvm::SmallVector<MlirType, 4> argTypes;
2772             llvm::SmallVector<MlirLocation, 4> argLocs;
2773             argTypes.reserve(pyArgTypes.size());
2774             argLocs.reserve(pyArgTypes.size());
2775             for (auto &pyArg : pyArgTypes) {
2776               argTypes.push_back(pyArg.cast<PyType &>());
2777 
2778               // TODO: Pass in a proper location here.
2779               argLocs.push_back(
2780                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2781             }
2782             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2783                                               argLocs.data());
2784             MlirRegion region = mlirBlockGetParentRegion(self.get());
2785             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2786             return PyBlock(self.getParentOperation(), block);
2787           },
2788           "Creates and returns a new Block after this block "
2789           "(with given argument types).")
2790       .def(
2791           "__iter__",
2792           [](PyBlock &self) {
2793             self.checkValid();
2794             MlirOperation firstOperation =
2795                 mlirBlockGetFirstOperation(self.get());
2796             return PyOperationIterator(self.getParentOperation(),
2797                                        firstOperation);
2798           },
2799           "Iterates over operations in the block.")
2800       .def("__eq__",
2801            [](PyBlock &self, PyBlock &other) {
2802              return self.get().ptr == other.get().ptr;
2803            })
2804       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2805       .def(
2806           "__str__",
2807           [](PyBlock &self) {
2808             self.checkValid();
2809             PyPrintAccumulator printAccum;
2810             mlirBlockPrint(self.get(), printAccum.getCallback(),
2811                            printAccum.getUserData());
2812             return printAccum.join();
2813           },
2814           "Returns the assembly form of the block.")
2815       .def(
2816           "append",
2817           [](PyBlock &self, PyOperationBase &operation) {
2818             if (operation.getOperation().isAttached())
2819               operation.getOperation().detachFromParent();
2820 
2821             MlirOperation mlirOperation = operation.getOperation().get();
2822             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2823             operation.getOperation().setAttached(
2824                 self.getParentOperation().getObject());
2825           },
2826           py::arg("operation"),
2827           "Appends an operation to this block. If the operation is currently "
2828           "in another block, it will be moved.");
2829 
2830   //----------------------------------------------------------------------------
2831   // Mapping of PyInsertionPoint.
2832   //----------------------------------------------------------------------------
2833 
2834   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2835       .def(py::init<PyBlock &>(), py::arg("block"),
2836            "Inserts after the last operation but still inside the block.")
2837       .def("__enter__", &PyInsertionPoint::contextEnter)
2838       .def("__exit__", &PyInsertionPoint::contextExit)
2839       .def_property_readonly_static(
2840           "current",
2841           [](py::object & /*class*/) {
2842             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2843             if (!ip)
2844               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2845             return ip;
2846           },
2847           "Gets the InsertionPoint bound to the current thread or raises "
2848           "ValueError if none has been set")
2849       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2850            "Inserts before a referenced operation.")
2851       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2852                   py::arg("block"), "Inserts at the beginning of the block.")
2853       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2854                   py::arg("block"), "Inserts before the block terminator.")
2855       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2856            "Inserts an operation.")
2857       .def_property_readonly(
2858           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2859           "Returns the block that this InsertionPoint points to.");
2860 
2861   //----------------------------------------------------------------------------
2862   // Mapping of PyAttribute.
2863   //----------------------------------------------------------------------------
2864   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2865       // Delegate to the PyAttribute copy constructor, which will also lifetime
2866       // extend the backing context which owns the MlirAttribute.
2867       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2868            "Casts the passed attribute to the generic Attribute")
2869       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2870                              &PyAttribute::getCapsule)
2871       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2872       .def_static(
2873           "parse",
2874           [](std::string attrSpec, DefaultingPyMlirContext context) {
2875             MlirAttribute type = mlirAttributeParseGet(
2876                 context->get(), toMlirStringRef(attrSpec));
2877             // TODO: Rework error reporting once diagnostic engine is exposed
2878             // in C API.
2879             if (mlirAttributeIsNull(type)) {
2880               throw SetPyError(PyExc_ValueError,
2881                                Twine("Unable to parse attribute: '") +
2882                                    attrSpec + "'");
2883             }
2884             return PyAttribute(context->getRef(), type);
2885           },
2886           py::arg("asm"), py::arg("context") = py::none(),
2887           "Parses an attribute from an assembly form")
2888       .def_property_readonly(
2889           "context",
2890           [](PyAttribute &self) { return self.getContext().getObject(); },
2891           "Context that owns the Attribute")
2892       .def_property_readonly("type",
2893                              [](PyAttribute &self) {
2894                                return PyType(self.getContext()->getRef(),
2895                                              mlirAttributeGetType(self));
2896                              })
2897       .def(
2898           "get_named",
2899           [](PyAttribute &self, std::string name) {
2900             return PyNamedAttribute(self, std::move(name));
2901           },
2902           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2903       .def("__eq__",
2904            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2905       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2906       .def("__hash__",
2907            [](PyAttribute &self) {
2908              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2909            })
2910       .def(
2911           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2912           kDumpDocstring)
2913       .def(
2914           "__str__",
2915           [](PyAttribute &self) {
2916             PyPrintAccumulator printAccum;
2917             mlirAttributePrint(self, printAccum.getCallback(),
2918                                printAccum.getUserData());
2919             return printAccum.join();
2920           },
2921           "Returns the assembly form of the Attribute.")
2922       .def("__repr__", [](PyAttribute &self) {
2923         // Generally, assembly formats are not printed for __repr__ because
2924         // this can cause exceptionally long debug output and exceptions.
2925         // However, attribute values are generally considered useful and are
2926         // printed. This may need to be re-evaluated if debug dumps end up
2927         // being excessive.
2928         PyPrintAccumulator printAccum;
2929         printAccum.parts.append("Attribute(");
2930         mlirAttributePrint(self, printAccum.getCallback(),
2931                            printAccum.getUserData());
2932         printAccum.parts.append(")");
2933         return printAccum.join();
2934       });
2935 
2936   //----------------------------------------------------------------------------
2937   // Mapping of PyNamedAttribute
2938   //----------------------------------------------------------------------------
2939   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2940       .def("__repr__",
2941            [](PyNamedAttribute &self) {
2942              PyPrintAccumulator printAccum;
2943              printAccum.parts.append("NamedAttribute(");
2944              printAccum.parts.append(
2945                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
2946                          mlirIdentifierStr(self.namedAttr.name).length));
2947              printAccum.parts.append("=");
2948              mlirAttributePrint(self.namedAttr.attribute,
2949                                 printAccum.getCallback(),
2950                                 printAccum.getUserData());
2951              printAccum.parts.append(")");
2952              return printAccum.join();
2953            })
2954       .def_property_readonly(
2955           "name",
2956           [](PyNamedAttribute &self) {
2957             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2958                            mlirIdentifierStr(self.namedAttr.name).length);
2959           },
2960           "The name of the NamedAttribute binding")
2961       .def_property_readonly(
2962           "attr",
2963           [](PyNamedAttribute &self) {
2964             // TODO: When named attribute is removed/refactored, also remove
2965             // this constructor (it does an inefficient table lookup).
2966             auto contextRef = PyMlirContext::forContext(
2967                 mlirAttributeGetContext(self.namedAttr.attribute));
2968             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2969           },
2970           py::keep_alive<0, 1>(),
2971           "The underlying generic attribute of the NamedAttribute binding");
2972 
2973   //----------------------------------------------------------------------------
2974   // Mapping of PyType.
2975   //----------------------------------------------------------------------------
2976   py::class_<PyType>(m, "Type", py::module_local())
2977       // Delegate to the PyType copy constructor, which will also lifetime
2978       // extend the backing context which owns the MlirType.
2979       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2980            "Casts the passed type to the generic Type")
2981       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2982       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2983       .def_static(
2984           "parse",
2985           [](std::string typeSpec, DefaultingPyMlirContext context) {
2986             MlirType type =
2987                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2988             // TODO: Rework error reporting once diagnostic engine is exposed
2989             // in C API.
2990             if (mlirTypeIsNull(type)) {
2991               throw SetPyError(PyExc_ValueError,
2992                                Twine("Unable to parse type: '") + typeSpec +
2993                                    "'");
2994             }
2995             return PyType(context->getRef(), type);
2996           },
2997           py::arg("asm"), py::arg("context") = py::none(),
2998           kContextParseTypeDocstring)
2999       .def_property_readonly(
3000           "context", [](PyType &self) { return self.getContext().getObject(); },
3001           "Context that owns the Type")
3002       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3003       .def("__eq__", [](PyType &self, py::object &other) { return false; })
3004       .def("__hash__",
3005            [](PyType &self) {
3006              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3007            })
3008       .def(
3009           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3010       .def(
3011           "__str__",
3012           [](PyType &self) {
3013             PyPrintAccumulator printAccum;
3014             mlirTypePrint(self, printAccum.getCallback(),
3015                           printAccum.getUserData());
3016             return printAccum.join();
3017           },
3018           "Returns the assembly form of the type.")
3019       .def("__repr__", [](PyType &self) {
3020         // Generally, assembly formats are not printed for __repr__ because
3021         // this can cause exceptionally long debug output and exceptions.
3022         // However, types are an exception as they typically have compact
3023         // assembly forms and printing them is useful.
3024         PyPrintAccumulator printAccum;
3025         printAccum.parts.append("Type(");
3026         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
3027         printAccum.parts.append(")");
3028         return printAccum.join();
3029       });
3030 
3031   //----------------------------------------------------------------------------
3032   // Mapping of Value.
3033   //----------------------------------------------------------------------------
3034   py::class_<PyValue>(m, "Value", py::module_local())
3035       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3036       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3037       .def_property_readonly(
3038           "context",
3039           [](PyValue &self) { return self.getParentOperation()->getContext(); },
3040           "Context in which the value lives.")
3041       .def(
3042           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3043           kDumpDocstring)
3044       .def_property_readonly(
3045           "owner",
3046           [](PyValue &self) {
3047             assert(mlirOperationEqual(self.getParentOperation()->get(),
3048                                       mlirOpResultGetOwner(self.get())) &&
3049                    "expected the owner of the value in Python to match that in "
3050                    "the IR");
3051             return self.getParentOperation().getObject();
3052           })
3053       .def("__eq__",
3054            [](PyValue &self, PyValue &other) {
3055              return self.get().ptr == other.get().ptr;
3056            })
3057       .def("__eq__", [](PyValue &self, py::object other) { return false; })
3058       .def("__hash__",
3059            [](PyValue &self) {
3060              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3061            })
3062       .def(
3063           "__str__",
3064           [](PyValue &self) {
3065             PyPrintAccumulator printAccum;
3066             printAccum.parts.append("Value(");
3067             mlirValuePrint(self.get(), printAccum.getCallback(),
3068                            printAccum.getUserData());
3069             printAccum.parts.append(")");
3070             return printAccum.join();
3071           },
3072           kValueDunderStrDocstring)
3073       .def_property_readonly("type", [](PyValue &self) {
3074         return PyType(self.getParentOperation()->getContext(),
3075                       mlirValueGetType(self.get()));
3076       });
3077   PyBlockArgument::bind(m);
3078   PyOpResult::bind(m);
3079 
3080   //----------------------------------------------------------------------------
3081   // Mapping of SymbolTable.
3082   //----------------------------------------------------------------------------
3083   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3084       .def(py::init<PyOperationBase &>())
3085       .def("__getitem__", &PySymbolTable::dunderGetItem)
3086       .def("insert", &PySymbolTable::insert, py::arg("operation"))
3087       .def("erase", &PySymbolTable::erase, py::arg("operation"))
3088       .def("__delitem__", &PySymbolTable::dunderDel)
3089       .def("__contains__",
3090            [](PySymbolTable &table, const std::string &name) {
3091              return !mlirOperationIsNull(mlirSymbolTableLookup(
3092                  table, mlirStringRefCreate(name.data(), name.length())));
3093            })
3094       // Static helpers.
3095       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3096                   py::arg("symbol"), py::arg("name"))
3097       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3098                   py::arg("symbol"))
3099       .def_static("get_visibility", &PySymbolTable::getVisibility,
3100                   py::arg("symbol"))
3101       .def_static("set_visibility", &PySymbolTable::setVisibility,
3102                   py::arg("symbol"), py::arg("visibility"))
3103       .def_static("replace_all_symbol_uses",
3104                   &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3105                   py::arg("new_symbol"), py::arg("from_op"))
3106       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3107                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
3108                   py::arg("callback"));
3109 
  // Container bindings.
  // Each helper's static bind(m) is invoked with this module. The call order
  // is kept as-is; NOTE(review): pybind11 renders signatures using the
  // classes registered so far, so reordering may affect docstrings — confirm
  // before changing it.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  // Registered last, via the same static bind(m) pattern.
  PyGlobalDebugFlag::bind(m);
3124 }
3125