1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include <pybind11/stl.h>
23 
24 #include <utility>
25 
26 namespace py = pybind11;
27 using namespace mlir;
28 using namespace mlir::python;
29 
30 using llvm::SmallVector;
31 using llvm::StringRef;
32 using llvm::Twine;
33 
34 //------------------------------------------------------------------------------
35 // Docstrings (trivial, non-duplicated docstrings are included inline).
36 //------------------------------------------------------------------------------
37 
// Python-visible docstrings attached to the bindings defined in this file.
// The keyword-argument names mentioned here must match the names used in the
// corresponding py::arg(...) declarations.

static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new Module or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
148 
149 //------------------------------------------------------------------------------
150 // Utilities.
151 //------------------------------------------------------------------------------
152 
153 /// Helper for creating an @classmethod.
154 template <class Func, typename... Args>
155 py::object classmethod(Func f, Args... args) {
156   py::object cf = py::cpp_function(f, args...);
157   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
158 }
159 
160 static py::object
161 createCustomDialectWrapper(const std::string &dialectNamespace,
162                            py::object dialectDescriptor) {
163   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
164   if (!dialectClass) {
165     // Use the base class.
166     return py::cast(PyDialect(std::move(dialectDescriptor)));
167   }
168 
169   // Create the custom implementation.
170   return (*dialectClass)(std::move(dialectDescriptor));
171 }
172 
173 static MlirStringRef toMlirStringRef(const std::string &s) {
174   return mlirStringRefCreate(s.data(), s.size());
175 }
176 
177 /// Wrapper for the global LLVM debugging flag.
178 struct PyGlobalDebugFlag {
179   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
180 
181   static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
182 
183   static void bind(py::module &m) {
184     // Debug flags.
185     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
186         .def_property_static("flag", &PyGlobalDebugFlag::get,
187                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
188   }
189 };
190 
191 //------------------------------------------------------------------------------
192 // Collections.
193 //------------------------------------------------------------------------------
194 
195 namespace {
196 
197 class PyRegionIterator {
198 public:
199   PyRegionIterator(PyOperationRef operation)
200       : operation(std::move(operation)) {}
201 
202   PyRegionIterator &dunderIter() { return *this; }
203 
204   PyRegion dunderNext() {
205     operation->checkValid();
206     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
207       throw py::stop_iteration();
208     }
209     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
210     return PyRegion(operation, region);
211   }
212 
213   static void bind(py::module &m) {
214     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
215         .def("__iter__", &PyRegionIterator::dunderIter)
216         .def("__next__", &PyRegionIterator::dunderNext);
217   }
218 
219 private:
220   PyOperationRef operation;
221   int nextIndex = 0;
222 };
223 
224 /// Regions of an op are fixed length and indexed numerically so are represented
225 /// with a sequence-like container.
226 class PyRegionList {
227 public:
228   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
229 
230   intptr_t dunderLen() {
231     operation->checkValid();
232     return mlirOperationGetNumRegions(operation->get());
233   }
234 
235   PyRegion dunderGetItem(intptr_t index) {
236     // dunderLen checks validity.
237     if (index < 0 || index >= dunderLen()) {
238       throw SetPyError(PyExc_IndexError,
239                        "attempt to access out of bounds region");
240     }
241     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
242     return PyRegion(operation, region);
243   }
244 
245   static void bind(py::module &m) {
246     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
247         .def("__len__", &PyRegionList::dunderLen)
248         .def("__getitem__", &PyRegionList::dunderGetItem);
249   }
250 
251 private:
252   PyOperationRef operation;
253 };
254 
255 class PyBlockIterator {
256 public:
257   PyBlockIterator(PyOperationRef operation, MlirBlock next)
258       : operation(std::move(operation)), next(next) {}
259 
260   PyBlockIterator &dunderIter() { return *this; }
261 
262   PyBlock dunderNext() {
263     operation->checkValid();
264     if (mlirBlockIsNull(next)) {
265       throw py::stop_iteration();
266     }
267 
268     PyBlock returnBlock(operation, next);
269     next = mlirBlockGetNextInRegion(next);
270     return returnBlock;
271   }
272 
273   static void bind(py::module &m) {
274     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
275         .def("__iter__", &PyBlockIterator::dunderIter)
276         .def("__next__", &PyBlockIterator::dunderNext);
277   }
278 
279 private:
280   PyOperationRef operation;
281   MlirBlock next;
282 };
283 
284 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
285 /// we present them as a more full-featured list-like container but optimize
286 /// it for forward iteration. Blocks are always owned by a region.
287 class PyBlockList {
288 public:
289   PyBlockList(PyOperationRef operation, MlirRegion region)
290       : operation(std::move(operation)), region(region) {}
291 
292   PyBlockIterator dunderIter() {
293     operation->checkValid();
294     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
295   }
296 
297   intptr_t dunderLen() {
298     operation->checkValid();
299     intptr_t count = 0;
300     MlirBlock block = mlirRegionGetFirstBlock(region);
301     while (!mlirBlockIsNull(block)) {
302       count += 1;
303       block = mlirBlockGetNextInRegion(block);
304     }
305     return count;
306   }
307 
308   PyBlock dunderGetItem(intptr_t index) {
309     operation->checkValid();
310     if (index < 0) {
311       throw SetPyError(PyExc_IndexError,
312                        "attempt to access out of bounds block");
313     }
314     MlirBlock block = mlirRegionGetFirstBlock(region);
315     while (!mlirBlockIsNull(block)) {
316       if (index == 0) {
317         return PyBlock(operation, block);
318       }
319       block = mlirBlockGetNextInRegion(block);
320       index -= 1;
321     }
322     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
323   }
324 
325   PyBlock appendBlock(const py::args &pyArgTypes) {
326     operation->checkValid();
327     llvm::SmallVector<MlirType, 4> argTypes;
328     llvm::SmallVector<MlirLocation, 4> argLocs;
329     argTypes.reserve(pyArgTypes.size());
330     argLocs.reserve(pyArgTypes.size());
331     for (auto &pyArg : pyArgTypes) {
332       argTypes.push_back(pyArg.cast<PyType &>());
333       // TODO: Pass in a proper location here.
334       argLocs.push_back(
335           mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
336     }
337 
338     MlirBlock block =
339         mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
340     mlirRegionAppendOwnedBlock(region, block);
341     return PyBlock(operation, block);
342   }
343 
344   static void bind(py::module &m) {
345     py::class_<PyBlockList>(m, "BlockList", py::module_local())
346         .def("__getitem__", &PyBlockList::dunderGetItem)
347         .def("__iter__", &PyBlockList::dunderIter)
348         .def("__len__", &PyBlockList::dunderLen)
349         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
350   }
351 
352 private:
353   PyOperationRef operation;
354   MlirRegion region;
355 };
356 
357 class PyOperationIterator {
358 public:
359   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
360       : parentOperation(std::move(parentOperation)), next(next) {}
361 
362   PyOperationIterator &dunderIter() { return *this; }
363 
364   py::object dunderNext() {
365     parentOperation->checkValid();
366     if (mlirOperationIsNull(next)) {
367       throw py::stop_iteration();
368     }
369 
370     PyOperationRef returnOperation =
371         PyOperation::forOperation(parentOperation->getContext(), next);
372     next = mlirOperationGetNextInBlock(next);
373     return returnOperation->createOpView();
374   }
375 
376   static void bind(py::module &m) {
377     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
378         .def("__iter__", &PyOperationIterator::dunderIter)
379         .def("__next__", &PyOperationIterator::dunderNext);
380   }
381 
382 private:
383   PyOperationRef parentOperation;
384   MlirOperation next;
385 };
386 
387 /// Operations are exposed by the C-API as a forward-only linked list. In
388 /// Python, we present them as a more full-featured list-like container but
389 /// optimize it for forward iteration. Iterable operations are always owned
390 /// by a block.
391 class PyOperationList {
392 public:
393   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
394       : parentOperation(std::move(parentOperation)), block(block) {}
395 
396   PyOperationIterator dunderIter() {
397     parentOperation->checkValid();
398     return PyOperationIterator(parentOperation,
399                                mlirBlockGetFirstOperation(block));
400   }
401 
402   intptr_t dunderLen() {
403     parentOperation->checkValid();
404     intptr_t count = 0;
405     MlirOperation childOp = mlirBlockGetFirstOperation(block);
406     while (!mlirOperationIsNull(childOp)) {
407       count += 1;
408       childOp = mlirOperationGetNextInBlock(childOp);
409     }
410     return count;
411   }
412 
413   py::object dunderGetItem(intptr_t index) {
414     parentOperation->checkValid();
415     if (index < 0) {
416       throw SetPyError(PyExc_IndexError,
417                        "attempt to access out of bounds operation");
418     }
419     MlirOperation childOp = mlirBlockGetFirstOperation(block);
420     while (!mlirOperationIsNull(childOp)) {
421       if (index == 0) {
422         return PyOperation::forOperation(parentOperation->getContext(), childOp)
423             ->createOpView();
424       }
425       childOp = mlirOperationGetNextInBlock(childOp);
426       index -= 1;
427     }
428     throw SetPyError(PyExc_IndexError,
429                      "attempt to access out of bounds operation");
430   }
431 
432   static void bind(py::module &m) {
433     py::class_<PyOperationList>(m, "OperationList", py::module_local())
434         .def("__getitem__", &PyOperationList::dunderGetItem)
435         .def("__iter__", &PyOperationList::dunderIter)
436         .def("__len__", &PyOperationList::dunderLen);
437   }
438 
439 private:
440   PyOperationRef parentOperation;
441   MlirBlock block;
442 };
443 
444 } // namespace
445 
446 //------------------------------------------------------------------------------
447 // PyMlirContext
448 //------------------------------------------------------------------------------
449 
450 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
451   py::gil_scoped_acquire acquire;
452   auto &liveContexts = getLiveContexts();
453   liveContexts[context.ptr] = this;
454 }
455 
456 PyMlirContext::~PyMlirContext() {
457   // Note that the only public way to construct an instance is via the
458   // forContext method, which always puts the associated handle into
459   // liveContexts.
460   py::gil_scoped_acquire acquire;
461   getLiveContexts().erase(context.ptr);
462   mlirContextDestroy(context);
463 }
464 
465 py::object PyMlirContext::getCapsule() {
466   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
467 }
468 
469 py::object PyMlirContext::createFromCapsule(py::object capsule) {
470   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
471   if (mlirContextIsNull(rawContext))
472     throw py::error_already_set();
473   return forContext(rawContext).releaseObject();
474 }
475 
476 PyMlirContext *PyMlirContext::createNewContextForInit() {
477   MlirContext context = mlirContextCreate();
478   mlirRegisterAllDialects(context);
479   return new PyMlirContext(context);
480 }
481 
482 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
483   py::gil_scoped_acquire acquire;
484   auto &liveContexts = getLiveContexts();
485   auto it = liveContexts.find(context.ptr);
486   if (it == liveContexts.end()) {
487     // Create.
488     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
489     py::object pyRef = py::cast(unownedContextWrapper);
490     assert(pyRef && "cast to py::object failed");
491     liveContexts[context.ptr] = unownedContextWrapper;
492     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
493   }
494   // Use existing.
495   py::object pyRef = py::cast(it->second);
496   return PyMlirContextRef(it->second, std::move(pyRef));
497 }
498 
499 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
500   static LiveContextMap liveContexts;
501   return liveContexts;
502 }
503 
504 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
505 
506 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
507 
508 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
509 
510 pybind11::object PyMlirContext::contextEnter() {
511   return PyThreadContextEntry::pushContext(*this);
512 }
513 
514 void PyMlirContext::contextExit(const pybind11::object &excType,
515                                 const pybind11::object &excVal,
516                                 const pybind11::object &excTb) {
517   PyThreadContextEntry::popContext(*this);
518 }
519 
520 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
521   // Note that ownership is transferred to the delete callback below by way of
522   // an explicit inc_ref (borrow).
523   PyDiagnosticHandler *pyHandler =
524       new PyDiagnosticHandler(get(), std::move(callback));
525   py::object pyHandlerObject =
526       py::cast(pyHandler, py::return_value_policy::take_ownership);
527   pyHandlerObject.inc_ref();
528 
529   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
530   // guaranteed to be known to pybind.
531   auto handlerCallback =
532       +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
533     PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
534     py::object pyDiagnosticObject =
535         py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
536 
537     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
538     bool result = false;
539     {
540       // Since this can be called from arbitrary C++ contexts, always get the
541       // gil.
542       py::gil_scoped_acquire gil;
543       try {
544         result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
545       } catch (std::exception &e) {
546         fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
547                 e.what());
548         pyHandler->hadError = true;
549       }
550     }
551 
552     pyDiagnostic->invalidate();
553     return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
554   };
555   auto deleteCallback = +[](void *userData) {
556     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
557     assert(pyHandler->registeredID && "handler is not registered");
558     pyHandler->registeredID.reset();
559 
560     // Decrement reference, balancing the inc_ref() above.
561     py::object pyHandlerObject =
562         py::cast(pyHandler, py::return_value_policy::reference);
563     pyHandlerObject.dec_ref();
564   };
565 
566   pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
567       get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
568   return pyHandlerObject;
569 }
570 
571 PyMlirContext &DefaultingPyMlirContext::resolve() {
572   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
573   if (!context) {
574     throw SetPyError(
575         PyExc_RuntimeError,
576         "An MLIR function requires a Context but none was provided in the call "
577         "or from the surrounding environment. Either pass to the function with "
578         "a 'context=' argument or establish a default using 'with Context():'");
579   }
580   return *context;
581 }
582 
583 //------------------------------------------------------------------------------
584 // PyThreadContextEntry management
585 //------------------------------------------------------------------------------
586 
587 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
588   static thread_local std::vector<PyThreadContextEntry> stack;
589   return stack;
590 }
591 
592 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
593   auto &stack = getStack();
594   if (stack.empty())
595     return nullptr;
596   return &stack.back();
597 }
598 
599 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
600                                 py::object insertionPoint,
601                                 py::object location) {
602   auto &stack = getStack();
603   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
604                      std::move(location));
605   // If the new stack has more than one entry and the context of the new top
606   // entry matches the previous, copy the insertionPoint and location from the
607   // previous entry if missing from the new top entry.
608   if (stack.size() > 1) {
609     auto &prev = *(stack.rbegin() + 1);
610     auto &current = stack.back();
611     if (current.context.is(prev.context)) {
612       // Default non-context objects from the previous entry.
613       if (!current.insertionPoint)
614         current.insertionPoint = prev.insertionPoint;
615       if (!current.location)
616         current.location = prev.location;
617     }
618   }
619 }
620 
621 PyMlirContext *PyThreadContextEntry::getContext() {
622   if (!context)
623     return nullptr;
624   return py::cast<PyMlirContext *>(context);
625 }
626 
627 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
628   if (!insertionPoint)
629     return nullptr;
630   return py::cast<PyInsertionPoint *>(insertionPoint);
631 }
632 
633 PyLocation *PyThreadContextEntry::getLocation() {
634   if (!location)
635     return nullptr;
636   return py::cast<PyLocation *>(location);
637 }
638 
639 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
640   auto *tos = getTopOfStack();
641   return tos ? tos->getContext() : nullptr;
642 }
643 
644 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
645   auto *tos = getTopOfStack();
646   return tos ? tos->getInsertionPoint() : nullptr;
647 }
648 
649 PyLocation *PyThreadContextEntry::getDefaultLocation() {
650   auto *tos = getTopOfStack();
651   return tos ? tos->getLocation() : nullptr;
652 }
653 
654 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
655   py::object contextObj = py::cast(context);
656   push(FrameKind::Context, /*context=*/contextObj,
657        /*insertionPoint=*/py::object(),
658        /*location=*/py::object());
659   return contextObj;
660 }
661 
662 void PyThreadContextEntry::popContext(PyMlirContext &context) {
663   auto &stack = getStack();
664   if (stack.empty())
665     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
666   auto &tos = stack.back();
667   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
668     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
669   stack.pop_back();
670 }
671 
672 py::object
673 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
674   py::object contextObj =
675       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
676   py::object insertionPointObj = py::cast(insertionPoint);
677   push(FrameKind::InsertionPoint,
678        /*context=*/contextObj,
679        /*insertionPoint=*/insertionPointObj,
680        /*location=*/py::object());
681   return insertionPointObj;
682 }
683 
684 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
685   auto &stack = getStack();
686   if (stack.empty())
687     throw SetPyError(PyExc_RuntimeError,
688                      "Unbalanced InsertionPoint enter/exit");
689   auto &tos = stack.back();
690   if (tos.frameKind != FrameKind::InsertionPoint &&
691       tos.getInsertionPoint() != &insertionPoint)
692     throw SetPyError(PyExc_RuntimeError,
693                      "Unbalanced InsertionPoint enter/exit");
694   stack.pop_back();
695 }
696 
697 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
698   py::object contextObj = location.getContext().getObject();
699   py::object locationObj = py::cast(location);
700   push(FrameKind::Location, /*context=*/contextObj,
701        /*insertionPoint=*/py::object(),
702        /*location=*/locationObj);
703   return locationObj;
704 }
705 
706 void PyThreadContextEntry::popLocation(PyLocation &location) {
707   auto &stack = getStack();
708   if (stack.empty())
709     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
710   auto &tos = stack.back();
711   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
712     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
713   stack.pop_back();
714 }
715 
716 //------------------------------------------------------------------------------
717 // PyDiagnostic*
718 //------------------------------------------------------------------------------
719 
720 void PyDiagnostic::invalidate() {
721   valid = false;
722   if (materializedNotes) {
723     for (auto &noteObject : *materializedNotes) {
724       PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
725       note->invalidate();
726     }
727   }
728 }
729 
730 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
731                                          py::object callback)
732     : context(context), callback(std::move(callback)) {}
733 
734 PyDiagnosticHandler::~PyDiagnosticHandler() = default;
735 
736 void PyDiagnosticHandler::detach() {
737   if (!registeredID)
738     return;
739   MlirDiagnosticHandlerID localID = *registeredID;
740   mlirContextDetachDiagnosticHandler(context, localID);
741   assert(!registeredID && "should have unregistered");
742   // Not strictly necessary but keeps stale pointers from being around to cause
743   // issues.
744   context = {nullptr};
745 }
746 
747 void PyDiagnostic::checkValid() {
748   if (!valid) {
749     throw std::invalid_argument(
750         "Diagnostic is invalid (used outside of callback)");
751   }
752 }
753 
754 MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
755   checkValid();
756   return mlirDiagnosticGetSeverity(diagnostic);
757 }
758 
759 PyLocation PyDiagnostic::getLocation() {
760   checkValid();
761   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
762   MlirContext context = mlirLocationGetContext(loc);
763   return PyLocation(PyMlirContext::forContext(context), loc);
764 }
765 
766 py::str PyDiagnostic::getMessage() {
767   checkValid();
768   py::object fileObject = py::module::import("io").attr("StringIO")();
769   PyFileAccumulator accum(fileObject, /*binary=*/false);
770   mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
771   return fileObject.attr("getvalue")();
772 }
773 
774 py::tuple PyDiagnostic::getNotes() {
775   checkValid();
776   if (materializedNotes)
777     return *materializedNotes;
778   intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
779   materializedNotes = py::tuple(numNotes);
780   for (intptr_t i = 0; i < numNotes; ++i) {
781     MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
782     py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag));
783     PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr());
784   }
785   return *materializedNotes;
786 }
787 
788 //------------------------------------------------------------------------------
789 // PyDialect, PyDialectDescriptor, PyDialects
790 //------------------------------------------------------------------------------
791 
792 MlirDialect PyDialects::getDialectForKey(const std::string &key,
793                                          bool attrError) {
794   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
795                                                     {key.data(), key.size()});
796   if (mlirDialectIsNull(dialect)) {
797     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
798                      Twine("Dialect '") + key + "' not found");
799   }
800   return dialect;
801 }
802 
803 //------------------------------------------------------------------------------
804 // PyLocation
805 //------------------------------------------------------------------------------
806 
807 py::object PyLocation::getCapsule() {
808   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
809 }
810 
811 PyLocation PyLocation::createFromCapsule(py::object capsule) {
812   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
813   if (mlirLocationIsNull(rawLoc))
814     throw py::error_already_set();
815   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
816                     rawLoc);
817 }
818 
819 py::object PyLocation::contextEnter() {
820   return PyThreadContextEntry::pushLocation(*this);
821 }
822 
823 void PyLocation::contextExit(const pybind11::object &excType,
824                              const pybind11::object &excVal,
825                              const pybind11::object &excTb) {
826   PyThreadContextEntry::popLocation(*this);
827 }
828 
829 PyLocation &DefaultingPyLocation::resolve() {
830   auto *location = PyThreadContextEntry::getDefaultLocation();
831   if (!location) {
832     throw SetPyError(
833         PyExc_RuntimeError,
834         "An MLIR function requires a Location but none was provided in the "
835         "call or from the surrounding environment. Either pass to the function "
836         "with a 'loc=' argument or establish a default using 'with loc:'");
837   }
838   return *location;
839 }
840 
841 //------------------------------------------------------------------------------
842 // PyModule
843 //------------------------------------------------------------------------------
844 
845 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
846     : BaseContextObject(std::move(contextRef)), module(module) {}
847 
848 PyModule::~PyModule() {
849   py::gil_scoped_acquire acquire;
850   auto &liveModules = getContext()->liveModules;
851   assert(liveModules.count(module.ptr) == 1 &&
852          "destroying module not in live map");
853   liveModules.erase(module.ptr);
854   mlirModuleDestroy(module);
855 }
856 
857 PyModuleRef PyModule::forModule(MlirModule module) {
858   MlirContext context = mlirModuleGetContext(module);
859   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
860 
861   py::gil_scoped_acquire acquire;
862   auto &liveModules = contextRef->liveModules;
863   auto it = liveModules.find(module.ptr);
864   if (it == liveModules.end()) {
865     // Create.
866     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
867     // Note that the default return value policy on cast is automatic_reference,
868     // which does not take ownership (delete will not be called).
869     // Just be explicit.
870     py::object pyRef =
871         py::cast(unownedModule, py::return_value_policy::take_ownership);
872     unownedModule->handle = pyRef;
873     liveModules[module.ptr] =
874         std::make_pair(unownedModule->handle, unownedModule);
875     return PyModuleRef(unownedModule, std::move(pyRef));
876   }
877   // Use existing.
878   PyModule *existing = it->second.second;
879   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
880   return PyModuleRef(existing, std::move(pyRef));
881 }
882 
883 py::object PyModule::createFromCapsule(py::object capsule) {
884   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
885   if (mlirModuleIsNull(rawModule))
886     throw py::error_already_set();
887   return forModule(rawModule).releaseObject();
888 }
889 
890 py::object PyModule::getCapsule() {
891   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
892 }
893 
894 //------------------------------------------------------------------------------
895 // PyOperation
896 //------------------------------------------------------------------------------
897 
898 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
899     : BaseContextObject(std::move(contextRef)), operation(operation) {}
900 
901 PyOperation::~PyOperation() {
902   // If the operation has already been invalidated there is nothing to do.
903   if (!valid)
904     return;
905   auto &liveOperations = getContext()->liveOperations;
906   assert(liveOperations.count(operation.ptr) == 1 &&
907          "destroying operation not in live map");
908   liveOperations.erase(operation.ptr);
909   if (!isAttached()) {
910     mlirOperationDestroy(operation);
911   }
912 }
913 
914 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
915                                            MlirOperation operation,
916                                            py::object parentKeepAlive) {
917   auto &liveOperations = contextRef->liveOperations;
918   // Create.
919   PyOperation *unownedOperation =
920       new PyOperation(std::move(contextRef), operation);
921   // Note that the default return value policy on cast is automatic_reference,
922   // which does not take ownership (delete will not be called).
923   // Just be explicit.
924   py::object pyRef =
925       py::cast(unownedOperation, py::return_value_policy::take_ownership);
926   unownedOperation->handle = pyRef;
927   if (parentKeepAlive) {
928     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
929   }
930   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
931   return PyOperationRef(unownedOperation, std::move(pyRef));
932 }
933 
934 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
935                                          MlirOperation operation,
936                                          py::object parentKeepAlive) {
937   auto &liveOperations = contextRef->liveOperations;
938   auto it = liveOperations.find(operation.ptr);
939   if (it == liveOperations.end()) {
940     // Create.
941     return createInstance(std::move(contextRef), operation,
942                           std::move(parentKeepAlive));
943   }
944   // Use existing.
945   PyOperation *existing = it->second.second;
946   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
947   return PyOperationRef(existing, std::move(pyRef));
948 }
949 
950 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
951                                            MlirOperation operation,
952                                            py::object parentKeepAlive) {
953   auto &liveOperations = contextRef->liveOperations;
954   assert(liveOperations.count(operation.ptr) == 0 &&
955          "cannot create detached operation that already exists");
956   (void)liveOperations;
957 
958   PyOperationRef created = createInstance(std::move(contextRef), operation,
959                                           std::move(parentKeepAlive));
960   created->attached = false;
961   return created;
962 }
963 
964 void PyOperation::checkValid() const {
965   if (!valid) {
966     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
967   }
968 }
969 
970 void PyOperationBase::print(py::object fileObject, bool binary,
971                             llvm::Optional<int64_t> largeElementsLimit,
972                             bool enableDebugInfo, bool prettyDebugInfo,
973                             bool printGenericOpForm, bool useLocalScope,
974                             bool assumeVerified) {
975   PyOperation &operation = getOperation();
976   operation.checkValid();
977   if (fileObject.is_none())
978     fileObject = py::module::import("sys").attr("stdout");
979 
980   if (!assumeVerified && !printGenericOpForm &&
981       !mlirOperationVerify(operation)) {
982     std::string message("// Verification failed, printing generic form\n");
983     if (binary) {
984       fileObject.attr("write")(py::bytes(message));
985     } else {
986       fileObject.attr("write")(py::str(message));
987     }
988     printGenericOpForm = true;
989   }
990 
991   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
992   if (largeElementsLimit)
993     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
994   if (enableDebugInfo)
995     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
996   if (printGenericOpForm)
997     mlirOpPrintingFlagsPrintGenericOpForm(flags);
998 
999   PyFileAccumulator accum(fileObject, binary);
1000   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1001                               accum.getUserData());
1002   mlirOpPrintingFlagsDestroy(flags);
1003 }
1004 
1005 py::object PyOperationBase::getAsm(bool binary,
1006                                    llvm::Optional<int64_t> largeElementsLimit,
1007                                    bool enableDebugInfo, bool prettyDebugInfo,
1008                                    bool printGenericOpForm, bool useLocalScope,
1009                                    bool assumeVerified) {
1010   py::object fileObject;
1011   if (binary) {
1012     fileObject = py::module::import("io").attr("BytesIO")();
1013   } else {
1014     fileObject = py::module::import("io").attr("StringIO")();
1015   }
1016   print(fileObject, /*binary=*/binary,
1017         /*largeElementsLimit=*/largeElementsLimit,
1018         /*enableDebugInfo=*/enableDebugInfo,
1019         /*prettyDebugInfo=*/prettyDebugInfo,
1020         /*printGenericOpForm=*/printGenericOpForm,
1021         /*useLocalScope=*/useLocalScope,
1022         /*assumeVerified=*/assumeVerified);
1023 
1024   return fileObject.attr("getvalue")();
1025 }
1026 
1027 void PyOperationBase::moveAfter(PyOperationBase &other) {
1028   PyOperation &operation = getOperation();
1029   PyOperation &otherOp = other.getOperation();
1030   operation.checkValid();
1031   otherOp.checkValid();
1032   mlirOperationMoveAfter(operation, otherOp);
1033   operation.parentKeepAlive = otherOp.parentKeepAlive;
1034 }
1035 
1036 void PyOperationBase::moveBefore(PyOperationBase &other) {
1037   PyOperation &operation = getOperation();
1038   PyOperation &otherOp = other.getOperation();
1039   operation.checkValid();
1040   otherOp.checkValid();
1041   mlirOperationMoveBefore(operation, otherOp);
1042   operation.parentKeepAlive = otherOp.parentKeepAlive;
1043 }
1044 
1045 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
1046   checkValid();
1047   if (!isAttached())
1048     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
1049   MlirOperation operation = mlirOperationGetParentOperation(get());
1050   if (mlirOperationIsNull(operation))
1051     return {};
1052   return PyOperation::forOperation(getContext(), operation);
1053 }
1054 
1055 PyBlock PyOperation::getBlock() {
1056   checkValid();
1057   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
1058   MlirBlock block = mlirOperationGetBlock(get());
1059   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1060   assert(parentOperation && "Operation has no parent");
1061   return PyBlock{std::move(*parentOperation), block};
1062 }
1063 
1064 py::object PyOperation::getCapsule() {
1065   checkValid();
1066   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1067 }
1068 
1069 py::object PyOperation::createFromCapsule(py::object capsule) {
1070   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1071   if (mlirOperationIsNull(rawOperation))
1072     throw py::error_already_set();
1073   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1074   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1075       .releaseObject();
1076 }
1077 
1078 static void maybeInsertOperation(PyOperationRef &op,
1079                                  const py::object &maybeIp) {
1080   // InsertPoint active?
1081   if (!maybeIp.is(py::cast(false))) {
1082     PyInsertionPoint *ip;
1083     if (maybeIp.is_none()) {
1084       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1085     } else {
1086       ip = py::cast<PyInsertionPoint *>(maybeIp);
1087     }
1088     if (ip)
1089       ip->insert(*op.get());
1090   }
1091 }
1092 
/// Generic operation constructor backing `Operation.create`.
///
/// Validates and unpacks the optional operands, result types, attributes
/// (a dict of str -> Attribute) and successor blocks, then builds an
/// MlirOperationState with `regions` freshly created regions, creates the
/// operation as detached, optionally inserts it per `maybeIp` (see
/// maybeInsertOperation) and returns its OpView.
///
/// Raises ValueError for a negative region count or None operands/results/
/// successors, and py::cast_error for non-string attribute keys or
/// non-Attribute values.
py::object PyOperation::create(
    const std::string &name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, const py::object &maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Attribute names are copied into owned std::strings here so that the
  // MlirStringRefs built later can point at stable storage.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    // Each name is interned as an identifier in the attribute's own context.
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Fresh empty regions; the state is told they are owned, so they are
    // handed over to the operation being built.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  // The new operation starts detached; maybeInsertOperation may attach it.
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created->createOpView();
}
1214 
1215 py::object PyOperation::clone(const py::object &maybeIp) {
1216   MlirOperation clonedOperation = mlirOperationClone(operation);
1217   PyOperationRef cloned =
1218       PyOperation::createDetached(getContext(), clonedOperation);
1219   maybeInsertOperation(cloned, maybeIp);
1220 
1221   return cloned->createOpView();
1222 }
1223 
1224 py::object PyOperation::createOpView() {
1225   checkValid();
1226   MlirIdentifier ident = mlirOperationGetName(get());
1227   MlirStringRef identStr = mlirIdentifierStr(ident);
1228   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1229       StringRef(identStr.data, identStr.length));
1230   if (opViewClass)
1231     return (*opViewClass)(getRef().getObject());
1232   return py::cast(PyOpView(getRef().getObject()));
1233 }
1234 
1235 void PyOperation::erase() {
1236   checkValid();
1237   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1238   // Python reference to a child operation is live. All children should also
1239   // have their `valid` bit set to false.
1240   auto &liveOperations = getContext()->liveOperations;
1241   if (liveOperations.count(operation.ptr))
1242     liveOperations.erase(operation.ptr);
1243   mlirOperationDestroy(operation);
1244   valid = false;
1245 }
1246 
1247 //------------------------------------------------------------------------------
1248 // PyOpView
1249 //------------------------------------------------------------------------------
1250 
1251 py::object PyOpView::buildGeneric(
1252     const py::object &cls, py::list resultTypeList, py::list operandList,
1253     llvm::Optional<py::dict> attributes,
1254     llvm::Optional<std::vector<PyBlock *>> successors,
1255     llvm::Optional<int> regions, DefaultingPyLocation location,
1256     const py::object &maybeIp) {
1257   PyMlirContextRef context = location->getContext();
1258   // Class level operation construction metadata.
1259   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1260   // Operand and result segment specs are either none, which does no
1261   // variadic unpacking, or a list of ints with segment sizes, where each
1262   // element is either a positive number (typically 1 for a scalar) or -1 to
1263   // indicate that it is derived from the length of the same-indexed operand
1264   // or result (implying that it is a list at that position).
1265   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1266   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1267 
1268   std::vector<uint32_t> operandSegmentLengths;
1269   std::vector<uint32_t> resultSegmentLengths;
1270 
1271   // Validate/determine region count.
1272   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1273   int opMinRegionCount = std::get<0>(opRegionSpec);
1274   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1275   if (!regions) {
1276     regions = opMinRegionCount;
1277   }
1278   if (*regions < opMinRegionCount) {
1279     throw py::value_error(
1280         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1281          llvm::Twine(opMinRegionCount) +
1282          " regions but was built with regions=" + llvm::Twine(*regions))
1283             .str());
1284   }
1285   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1286     throw py::value_error(
1287         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1288          llvm::Twine(opMinRegionCount) +
1289          " regions but was built with regions=" + llvm::Twine(*regions))
1290             .str());
1291   }
1292 
1293   // Unpack results.
1294   std::vector<PyType *> resultTypes;
1295   resultTypes.reserve(resultTypeList.size());
1296   if (resultSegmentSpecObj.is_none()) {
1297     // Non-variadic result unpacking.
1298     for (const auto &it : llvm::enumerate(resultTypeList)) {
1299       try {
1300         resultTypes.push_back(py::cast<PyType *>(it.value()));
1301         if (!resultTypes.back())
1302           throw py::cast_error();
1303       } catch (py::cast_error &err) {
1304         throw py::value_error((llvm::Twine("Result ") +
1305                                llvm::Twine(it.index()) + " of operation \"" +
1306                                name + "\" must be a Type (" + err.what() + ")")
1307                                   .str());
1308       }
1309     }
1310   } else {
1311     // Sized result unpacking.
1312     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1313     if (resultSegmentSpec.size() != resultTypeList.size()) {
1314       throw py::value_error((llvm::Twine("Operation \"") + name +
1315                              "\" requires " +
1316                              llvm::Twine(resultSegmentSpec.size()) +
1317                              " result segments but was provided " +
1318                              llvm::Twine(resultTypeList.size()))
1319                                 .str());
1320     }
1321     resultSegmentLengths.reserve(resultTypeList.size());
1322     for (const auto &it :
1323          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1324       int segmentSpec = std::get<1>(it.value());
1325       if (segmentSpec == 1 || segmentSpec == 0) {
1326         // Unpack unary element.
1327         try {
1328           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1329           if (resultType) {
1330             resultTypes.push_back(resultType);
1331             resultSegmentLengths.push_back(1);
1332           } else if (segmentSpec == 0) {
1333             // Allowed to be optional.
1334             resultSegmentLengths.push_back(0);
1335           } else {
1336             throw py::cast_error("was None and result is not optional");
1337           }
1338         } catch (py::cast_error &err) {
1339           throw py::value_error((llvm::Twine("Result ") +
1340                                  llvm::Twine(it.index()) + " of operation \"" +
1341                                  name + "\" must be a Type (" + err.what() +
1342                                  ")")
1343                                     .str());
1344         }
1345       } else if (segmentSpec == -1) {
1346         // Unpack sequence by appending.
1347         try {
1348           if (std::get<0>(it.value()).is_none()) {
1349             // Treat it as an empty list.
1350             resultSegmentLengths.push_back(0);
1351           } else {
1352             // Unpack the list.
1353             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1354             for (py::object segmentItem : segment) {
1355               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1356               if (!resultTypes.back()) {
1357                 throw py::cast_error("contained a None item");
1358               }
1359             }
1360             resultSegmentLengths.push_back(segment.size());
1361           }
1362         } catch (std::exception &err) {
1363           // NOTE: Sloppy to be using a catch-all here, but there are at least
1364           // three different unrelated exceptions that can be thrown in the
1365           // above "casts". Just keep the scope above small and catch them all.
1366           throw py::value_error((llvm::Twine("Result ") +
1367                                  llvm::Twine(it.index()) + " of operation \"" +
1368                                  name + "\" must be a Sequence of Types (" +
1369                                  err.what() + ")")
1370                                     .str());
1371         }
1372       } else {
1373         throw py::value_error("Unexpected segment spec");
1374       }
1375     }
1376   }
1377 
1378   // Unpack operands.
1379   std::vector<PyValue *> operands;
1380   operands.reserve(operands.size());
1381   if (operandSegmentSpecObj.is_none()) {
1382     // Non-sized operand unpacking.
1383     for (const auto &it : llvm::enumerate(operandList)) {
1384       try {
1385         operands.push_back(py::cast<PyValue *>(it.value()));
1386         if (!operands.back())
1387           throw py::cast_error();
1388       } catch (py::cast_error &err) {
1389         throw py::value_error((llvm::Twine("Operand ") +
1390                                llvm::Twine(it.index()) + " of operation \"" +
1391                                name + "\" must be a Value (" + err.what() + ")")
1392                                   .str());
1393       }
1394     }
1395   } else {
1396     // Sized operand unpacking.
1397     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1398     if (operandSegmentSpec.size() != operandList.size()) {
1399       throw py::value_error((llvm::Twine("Operation \"") + name +
1400                              "\" requires " +
1401                              llvm::Twine(operandSegmentSpec.size()) +
1402                              "operand segments but was provided " +
1403                              llvm::Twine(operandList.size()))
1404                                 .str());
1405     }
1406     operandSegmentLengths.reserve(operandList.size());
1407     for (const auto &it :
1408          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1409       int segmentSpec = std::get<1>(it.value());
1410       if (segmentSpec == 1 || segmentSpec == 0) {
1411         // Unpack unary element.
1412         try {
1413           auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1414           if (operandValue) {
1415             operands.push_back(operandValue);
1416             operandSegmentLengths.push_back(1);
1417           } else if (segmentSpec == 0) {
1418             // Allowed to be optional.
1419             operandSegmentLengths.push_back(0);
1420           } else {
1421             throw py::cast_error("was None and operand is not optional");
1422           }
1423         } catch (py::cast_error &err) {
1424           throw py::value_error((llvm::Twine("Operand ") +
1425                                  llvm::Twine(it.index()) + " of operation \"" +
1426                                  name + "\" must be a Value (" + err.what() +
1427                                  ")")
1428                                     .str());
1429         }
1430       } else if (segmentSpec == -1) {
1431         // Unpack sequence by appending.
1432         try {
1433           if (std::get<0>(it.value()).is_none()) {
1434             // Treat it as an empty list.
1435             operandSegmentLengths.push_back(0);
1436           } else {
1437             // Unpack the list.
1438             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1439             for (py::object segmentItem : segment) {
1440               operands.push_back(py::cast<PyValue *>(segmentItem));
1441               if (!operands.back()) {
1442                 throw py::cast_error("contained a None item");
1443               }
1444             }
1445             operandSegmentLengths.push_back(segment.size());
1446           }
1447         } catch (std::exception &err) {
1448           // NOTE: Sloppy to be using a catch-all here, but there are at least
1449           // three different unrelated exceptions that can be thrown in the
1450           // above "casts". Just keep the scope above small and catch them all.
1451           throw py::value_error((llvm::Twine("Operand ") +
1452                                  llvm::Twine(it.index()) + " of operation \"" +
1453                                  name + "\" must be a Sequence of Values (" +
1454                                  err.what() + ")")
1455                                     .str());
1456         }
1457       } else {
1458         throw py::value_error("Unexpected segment spec");
1459       }
1460     }
1461   }
1462 
1463   // Merge operand/result segment lengths into attributes if needed.
1464   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1465     // Dup.
1466     if (attributes) {
1467       attributes = py::dict(*attributes);
1468     } else {
1469       attributes = py::dict();
1470     }
1471     if (attributes->contains("result_segment_sizes") ||
1472         attributes->contains("operand_segment_sizes")) {
1473       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1474                             "'operand_segment_sizes' attribute is unsupported. "
1475                             "Use Operation.create for such low-level access.");
1476     }
1477 
1478     // Add result_segment_sizes attribute.
1479     if (!resultSegmentLengths.empty()) {
1480       int64_t size = resultSegmentLengths.size();
1481       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1482           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1483           resultSegmentLengths.size(), resultSegmentLengths.data());
1484       (*attributes)["result_segment_sizes"] =
1485           PyAttribute(context, segmentLengthAttr);
1486     }
1487 
1488     // Add operand_segment_sizes attribute.
1489     if (!operandSegmentLengths.empty()) {
1490       int64_t size = operandSegmentLengths.size();
1491       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1492           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1493           operandSegmentLengths.size(), operandSegmentLengths.data());
1494       (*attributes)["operand_segment_sizes"] =
1495           PyAttribute(context, segmentLengthAttr);
1496     }
1497   }
1498 
1499   // Delegate to create.
1500   return PyOperation::create(name,
1501                              /*results=*/std::move(resultTypes),
1502                              /*operands=*/std::move(operands),
1503                              /*attributes=*/std::move(attributes),
1504                              /*successors=*/std::move(successors),
1505                              /*regions=*/*regions, location, maybeIp);
1506 }
1507 
1508 PyOpView::PyOpView(const py::object &operationObject)
1509     // Casting through the PyOperationBase base-class and then back to the
1510     // Operation lets us accept any PyOperationBase subclass.
1511     : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1512       operationObject(operation.getRef().getObject()) {}
1513 
1514 py::object PyOpView::createRawSubclass(const py::object &userClass) {
1515   // This is... a little gross. The typical pattern is to have a pure python
1516   // class that extends OpView like:
1517   //   class AddFOp(_cext.ir.OpView):
1518   //     def __init__(self, loc, lhs, rhs):
1519   //       operation = loc.context.create_operation(
1520   //           "addf", lhs, rhs, results=[lhs.type])
1521   //       super().__init__(operation)
1522   //
1523   // I.e. The goal of the user facing type is to provide a nice constructor
1524   // that has complete freedom for the op under construction. This is at odds
1525   // with our other desire to sometimes create this object by just passing an
1526   // operation (to initialize the base class). We could do *arg and **kwargs
1527   // munging to try to make it work, but instead, we synthesize a new class
1528   // on the fly which extends this user class (AddFOp in this example) and
1529   // *give it* the base class's __init__ method, thus bypassing the
1530   // intermediate subclass's __init__ method entirely. While slightly,
1531   // underhanded, this is safe/legal because the type hierarchy has not changed
1532   // (we just added a new leaf) and we aren't mucking around with __new__.
1533   // Typically, this new class will be stored on the original as "_Raw" and will
1534   // be used for casts and other things that need a variant of the class that
1535   // is initialized purely from an operation.
1536   py::object parentMetaclass =
1537       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1538   py::dict attributes;
1539   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1540   // now.
1541   //   auto opViewType = py::type::of<PyOpView>();
1542   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1543   attributes["__init__"] = opViewType.attr("__init__");
1544   py::str origName = userClass.attr("__name__");
1545   py::str newName = py::str("_") + origName;
1546   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1547 }
1548 
1549 //------------------------------------------------------------------------------
1550 // PyInsertionPoint.
1551 //------------------------------------------------------------------------------
1552 
/// Creates an insertion point with no reference operation. insert() treats
/// that as "insert at the end of `block`".
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}

/// Creates an insertion point that inserts immediately before
/// `beforeOperationBase`, in the block that contains that operation.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      // Reads the already-initialized `refOperation` member. NOTE(review):
      // relies on `refOperation` being declared before `block` in the
      // class — confirm in IRModule.h.
      block((*refOperation)->getBlock()) {}
1558 
1559 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1560   PyOperation &operation = operationBase.getOperation();
1561   if (operation.isAttached())
1562     throw SetPyError(PyExc_ValueError,
1563                      "Attempt to insert operation that is already attached");
1564   block.getParentOperation()->checkValid();
1565   MlirOperation beforeOp = {nullptr};
1566   if (refOperation) {
1567     // Insert before operation.
1568     (*refOperation)->checkValid();
1569     beforeOp = (*refOperation)->get();
1570   } else {
1571     // Insert at end (before null) is only valid if the block does not
1572     // already end in a known terminator (violating this will cause assertion
1573     // failures later).
1574     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1575       throw py::index_error("Cannot insert operation at the end of a block "
1576                             "that already has a terminator. Did you mean to "
1577                             "use 'InsertionPoint.at_block_terminator(block)' "
1578                             "versus 'InsertionPoint(block)'?");
1579     }
1580   }
1581   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1582   operation.setAttached();
1583 }
1584 
1585 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1586   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1587   if (mlirOperationIsNull(firstOp)) {
1588     // Just insert at end.
1589     return PyInsertionPoint(block);
1590   }
1591 
1592   // Insert before first op.
1593   PyOperationRef firstOpRef = PyOperation::forOperation(
1594       block.getParentOperation()->getContext(), firstOp);
1595   return PyInsertionPoint{block, std::move(firstOpRef)};
1596 }
1597 
1598 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1599   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1600   if (mlirOperationIsNull(terminator))
1601     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1602   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1603       block.getParentOperation()->getContext(), terminator);
1604   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1605 }
1606 
1607 py::object PyInsertionPoint::contextEnter() {
1608   return PyThreadContextEntry::pushInsertionPoint(*this);
1609 }
1610 
1611 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1612                                    const pybind11::object &excVal,
1613                                    const pybind11::object &excTb) {
1614   PyThreadContextEntry::popInsertionPoint(*this);
1615 }
1616 
1617 //------------------------------------------------------------------------------
1618 // PyAttribute.
1619 //------------------------------------------------------------------------------
1620 
1621 bool PyAttribute::operator==(const PyAttribute &other) {
1622   return mlirAttributeEqual(attr, other.attr);
1623 }
1624 
1625 py::object PyAttribute::getCapsule() {
1626   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1627 }
1628 
1629 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1630   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1631   if (mlirAttributeIsNull(rawAttr))
1632     throw py::error_already_set();
1633   return PyAttribute(
1634       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1635 }
1636 
1637 //------------------------------------------------------------------------------
1638 // PyNamedAttribute.
1639 //------------------------------------------------------------------------------
1640 
1641 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1642     : ownedName(new std::string(std::move(ownedName))) {
1643   namedAttr = mlirNamedAttributeGet(
1644       mlirIdentifierGet(mlirAttributeGetContext(attr),
1645                         toMlirStringRef(*this->ownedName)),
1646       attr);
1647 }
1648 
1649 //------------------------------------------------------------------------------
1650 // PyType.
1651 //------------------------------------------------------------------------------
1652 
1653 bool PyType::operator==(const PyType &other) {
1654   return mlirTypeEqual(type, other.type);
1655 }
1656 
1657 py::object PyType::getCapsule() {
1658   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1659 }
1660 
1661 PyType PyType::createFromCapsule(py::object capsule) {
1662   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1663   if (mlirTypeIsNull(rawType))
1664     throw py::error_already_set();
1665   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1666                 rawType);
1667 }
1668 
1669 //------------------------------------------------------------------------------
// PyValue and subclasses.
1671 //------------------------------------------------------------------------------
1672 
1673 pybind11::object PyValue::getCapsule() {
1674   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1675 }
1676 
1677 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1678   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1679   if (mlirValueIsNull(value))
1680     throw py::error_already_set();
1681   MlirOperation owner;
1682   if (mlirValueIsAOpResult(value))
1683     owner = mlirOpResultGetOwner(value);
1684   if (mlirValueIsABlockArgument(value))
1685     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1686   if (mlirOperationIsNull(owner))
1687     throw py::error_already_set();
1688   MlirContext ctx = mlirOperationGetContext(owner);
1689   PyOperationRef ownerRef =
1690       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1691   return PyValue(ownerRef, value);
1692 }
1693 
1694 //------------------------------------------------------------------------------
1695 // PySymbolTable.
1696 //------------------------------------------------------------------------------
1697 
1698 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1699     : operation(operation.getOperation().getRef()) {
1700   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1701   if (mlirSymbolTableIsNull(symbolTable)) {
1702     throw py::cast_error("Operation is not a Symbol Table.");
1703   }
1704 }
1705 
1706 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1707   operation->checkValid();
1708   MlirOperation symbol = mlirSymbolTableLookup(
1709       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1710   if (mlirOperationIsNull(symbol))
1711     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1712 
1713   return PyOperation::forOperation(operation->getContext(), symbol,
1714                                    operation.getObject())
1715       ->createOpView();
1716 }
1717 
1718 void PySymbolTable::erase(PyOperationBase &symbol) {
1719   operation->checkValid();
1720   symbol.getOperation().checkValid();
1721   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1722   // The operation is also erased, so we must invalidate it. There may be Python
1723   // references to this operation so we don't want to delete it from the list of
1724   // live operations here.
1725   symbol.getOperation().valid = false;
1726 }
1727 
1728 void PySymbolTable::dunderDel(const std::string &name) {
1729   py::object operation = dunderGetItem(name);
1730   erase(py::cast<PyOperationBase &>(operation));
1731 }
1732 
1733 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1734   operation->checkValid();
1735   symbol.getOperation().checkValid();
1736   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1737       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1738   if (mlirAttributeIsNull(symbolAttr))
1739     throw py::value_error("Expected operation to have a symbol name.");
1740   return PyAttribute(
1741       symbol.getOperation().getContext(),
1742       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1743 }
1744 
1745 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1746   // Op must already be a symbol.
1747   PyOperation &operation = symbol.getOperation();
1748   operation.checkValid();
1749   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1750   MlirAttribute existingNameAttr =
1751       mlirOperationGetAttributeByName(operation.get(), attrName);
1752   if (mlirAttributeIsNull(existingNameAttr))
1753     throw py::value_error("Expected operation to have a symbol name.");
1754   return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1755 }
1756 
1757 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1758                                   const std::string &name) {
1759   // Op must already be a symbol.
1760   PyOperation &operation = symbol.getOperation();
1761   operation.checkValid();
1762   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1763   MlirAttribute existingNameAttr =
1764       mlirOperationGetAttributeByName(operation.get(), attrName);
1765   if (mlirAttributeIsNull(existingNameAttr))
1766     throw py::value_error("Expected operation to have a symbol name.");
1767   MlirAttribute newNameAttr =
1768       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1769   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1770 }
1771 
1772 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1773   PyOperation &operation = symbol.getOperation();
1774   operation.checkValid();
1775   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1776   MlirAttribute existingVisAttr =
1777       mlirOperationGetAttributeByName(operation.get(), attrName);
1778   if (mlirAttributeIsNull(existingVisAttr))
1779     throw py::value_error("Expected operation to have a symbol visibility.");
1780   return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1781 }
1782 
1783 void PySymbolTable::setVisibility(PyOperationBase &symbol,
1784                                   const std::string &visibility) {
1785   if (visibility != "public" && visibility != "private" &&
1786       visibility != "nested")
1787     throw py::value_error(
1788         "Expected visibility to be 'public', 'private' or 'nested'");
1789   PyOperation &operation = symbol.getOperation();
1790   operation.checkValid();
1791   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1792   MlirAttribute existingVisAttr =
1793       mlirOperationGetAttributeByName(operation.get(), attrName);
1794   if (mlirAttributeIsNull(existingVisAttr))
1795     throw py::value_error("Expected operation to have a symbol visibility.");
1796   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1797                                                toMlirStringRef(visibility));
1798   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1799 }
1800 
1801 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1802                                          const std::string &newSymbol,
1803                                          PyOperationBase &from) {
1804   PyOperation &fromOperation = from.getOperation();
1805   fromOperation.checkValid();
1806   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1807           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1808           from.getOperation())))
1809 
1810     throw py::value_error("Symbol rename failed");
1811 }
1812 
1813 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1814                                      bool allSymUsesVisible,
1815                                      py::object callback) {
1816   PyOperation &fromOperation = from.getOperation();
1817   fromOperation.checkValid();
1818   struct UserData {
1819     PyMlirContextRef context;
1820     py::object callback;
1821     bool gotException;
1822     std::string exceptionWhat;
1823     py::object exceptionType;
1824   };
1825   UserData userData{
1826       fromOperation.getContext(), std::move(callback), false, {}, {}};
1827   mlirSymbolTableWalkSymbolTables(
1828       fromOperation.get(), allSymUsesVisible,
1829       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1830         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1831         auto pyFoundOp =
1832             PyOperation::forOperation(calleeUserData->context, foundOp);
1833         if (calleeUserData->gotException)
1834           return;
1835         try {
1836           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1837         } catch (py::error_already_set &e) {
1838           calleeUserData->gotException = true;
1839           calleeUserData->exceptionWhat = e.what();
1840           calleeUserData->exceptionType = e.type();
1841         }
1842       },
1843       static_cast<void *>(&userData));
1844   if (userData.gotException) {
1845     std::string message("Exception raised in callback: ");
1846     message.append(userData.exceptionWhat);
1847     throw std::runtime_error(message);
1848   }
1849 }
1850 
1851 namespace {
1852 /// CRTP base class for Python MLIR values that subclass Value and should be
1853 /// castable from it. The value hierarchy is one level deep and is not supposed
1854 /// to accommodate other levels unless core MLIR changes.
1855 template <typename DerivedTy>
1856 class PyConcreteValue : public PyValue {
1857 public:
1858   // Derived classes must define statics for:
1859   //   IsAFunctionTy isaFunction
1860   //   const char *pyClassName
1861   // and redefine bindDerived.
1862   using ClassTy = py::class_<DerivedTy, PyValue>;
1863   using IsAFunctionTy = bool (*)(MlirValue);
1864 
1865   PyConcreteValue() = default;
1866   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1867       : PyValue(operationRef, value) {}
1868   PyConcreteValue(PyValue &orig)
1869       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1870 
1871   /// Attempts to cast the original value to the derived type and throws on
1872   /// type mismatches.
1873   static MlirValue castFrom(PyValue &orig) {
1874     if (!DerivedTy::isaFunction(orig.get())) {
1875       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1876       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1877                                              DerivedTy::pyClassName +
1878                                              " (from " + origRepr + ")");
1879     }
1880     return orig.get();
1881   }
1882 
1883   /// Binds the Python module objects to functions of this class.
1884   static void bind(py::module &m) {
1885     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1886     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1887     cls.def_static(
1888         "isinstance",
1889         [](PyValue &otherValue) -> bool {
1890           return DerivedTy::isaFunction(otherValue);
1891         },
1892         py::arg("other_value"));
1893     DerivedTy::bindDerived(cls);
1894   }
1895 
1896   /// Implemented by derived classes to add methods to the Python subclass.
1897   static void bindDerived(ClassTy &m) {}
1898 };
1899 
1900 /// Python wrapper for MlirBlockArgument.
1901 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1902 public:
1903   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1904   static constexpr const char *pyClassName = "BlockArgument";
1905   using PyConcreteValue::PyConcreteValue;
1906 
1907   static void bindDerived(ClassTy &c) {
1908     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1909       return PyBlock(self.getParentOperation(),
1910                      mlirBlockArgumentGetOwner(self.get()));
1911     });
1912     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1913       return mlirBlockArgumentGetArgNumber(self.get());
1914     });
1915     c.def(
1916         "set_type",
1917         [](PyBlockArgument &self, PyType type) {
1918           return mlirBlockArgumentSetType(self.get(), type);
1919         },
1920         py::arg("type"));
1921   }
1922 };
1923 
1924 /// Python wrapper for MlirOpResult.
1925 class PyOpResult : public PyConcreteValue<PyOpResult> {
1926 public:
1927   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1928   static constexpr const char *pyClassName = "OpResult";
1929   using PyConcreteValue::PyConcreteValue;
1930 
1931   static void bindDerived(ClassTy &c) {
1932     c.def_property_readonly("owner", [](PyOpResult &self) {
1933       assert(
1934           mlirOperationEqual(self.getParentOperation()->get(),
1935                              mlirOpResultGetOwner(self.get())) &&
1936           "expected the owner of the value in Python to match that in the IR");
1937       return self.getParentOperation().getObject();
1938     });
1939     c.def_property_readonly("result_number", [](PyOpResult &self) {
1940       return mlirOpResultGetResultNumber(self.get());
1941     });
1942   }
1943 };
1944 
1945 /// Returns the list of types of the values held by container.
1946 template <typename Container>
1947 static std::vector<PyType> getValueTypes(Container &container,
1948                                          PyMlirContextRef &context) {
1949   std::vector<PyType> result;
1950   result.reserve(container.getNumElements());
1951   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1952     result.push_back(
1953         PyType(context, mlirValueGetType(container.getElement(i).get())));
1954   }
1955   return result;
1956 }
1957 
1958 /// A list of block arguments. Internally, these are stored as consecutive
1959 /// elements, random access is cheap. The argument list is associated with the
1960 /// operation that contains the block (detached blocks are not allowed in
1961 /// Python bindings) and extends its lifetime.
1962 class PyBlockArgumentList
1963     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1964 public:
1965   static constexpr const char *pyClassName = "BlockArgumentList";
1966 
1967   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1968                       intptr_t startIndex = 0, intptr_t length = -1,
1969                       intptr_t step = 1)
1970       : Sliceable(startIndex,
1971                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1972                   step),
1973         operation(std::move(operation)), block(block) {}
1974 
1975   /// Returns the number of arguments in the list.
1976   intptr_t getNumElements() {
1977     operation->checkValid();
1978     return mlirBlockGetNumArguments(block);
1979   }
1980 
1981   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1982   PyBlockArgument getElement(intptr_t pos) {
1983     MlirValue argument = mlirBlockGetArgument(block, pos);
1984     return PyBlockArgument(operation, argument);
1985   }
1986 
1987   /// Returns a sublist of this list.
1988   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1989                             intptr_t step) {
1990     return PyBlockArgumentList(operation, block, startIndex, length, step);
1991   }
1992 
1993   static void bindDerived(ClassTy &c) {
1994     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1995       return getValueTypes(self, self.operation->getContext());
1996     });
1997   }
1998 
1999 private:
2000   PyOperationRef operation;
2001   MlirBlock block;
2002 };
2003 
2004 /// A list of operation operands. Internally, these are stored as consecutive
2005 /// elements, random access is cheap. The result list is associated with the
2006 /// operation whose results these are, and extends the lifetime of this
2007 /// operation.
2008 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2009 public:
2010   static constexpr const char *pyClassName = "OpOperandList";
2011 
2012   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2013                   intptr_t length = -1, intptr_t step = 1)
2014       : Sliceable(startIndex,
2015                   length == -1 ? mlirOperationGetNumOperands(operation->get())
2016                                : length,
2017                   step),
2018         operation(operation) {}
2019 
2020   intptr_t getNumElements() {
2021     operation->checkValid();
2022     return mlirOperationGetNumOperands(operation->get());
2023   }
2024 
2025   PyValue getElement(intptr_t pos) {
2026     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2027     MlirOperation owner;
2028     if (mlirValueIsAOpResult(operand))
2029       owner = mlirOpResultGetOwner(operand);
2030     else if (mlirValueIsABlockArgument(operand))
2031       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2032     else
2033       assert(false && "Value must be an block arg or op result.");
2034     PyOperationRef pyOwner =
2035         PyOperation::forOperation(operation->getContext(), owner);
2036     return PyValue(pyOwner, operand);
2037   }
2038 
2039   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2040     return PyOpOperandList(operation, startIndex, length, step);
2041   }
2042 
2043   void dunderSetItem(intptr_t index, PyValue value) {
2044     index = wrapIndex(index);
2045     mlirOperationSetOperand(operation->get(), index, value.get());
2046   }
2047 
2048   static void bindDerived(ClassTy &c) {
2049     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2050   }
2051 
2052 private:
2053   PyOperationRef operation;
2054 };
2055 
2056 /// A list of operation results. Internally, these are stored as consecutive
2057 /// elements, random access is cheap. The result list is associated with the
2058 /// operation whose results these are, and extends the lifetime of this
2059 /// operation.
2060 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2061 public:
2062   static constexpr const char *pyClassName = "OpResultList";
2063 
2064   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2065                  intptr_t length = -1, intptr_t step = 1)
2066       : Sliceable(startIndex,
2067                   length == -1 ? mlirOperationGetNumResults(operation->get())
2068                                : length,
2069                   step),
2070         operation(operation) {}
2071 
2072   intptr_t getNumElements() {
2073     operation->checkValid();
2074     return mlirOperationGetNumResults(operation->get());
2075   }
2076 
2077   PyOpResult getElement(intptr_t index) {
2078     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2079     return PyOpResult(value);
2080   }
2081 
2082   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2083     return PyOpResultList(operation, startIndex, length, step);
2084   }
2085 
2086   static void bindDerived(ClassTy &c) {
2087     c.def_property_readonly("types", [](PyOpResultList &self) {
2088       return getValueTypes(self, self.operation->getContext());
2089     });
2090   }
2091 
2092 private:
2093   PyOperationRef operation;
2094 };
2095 
2096 /// A list of operation attributes. Can be indexed by name, producing
2097 /// attributes, or by index, producing named attributes.
2098 class PyOpAttributeMap {
2099 public:
2100   PyOpAttributeMap(PyOperationRef operation)
2101       : operation(std::move(operation)) {}
2102 
2103   PyAttribute dunderGetItemNamed(const std::string &name) {
2104     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2105                                                          toMlirStringRef(name));
2106     if (mlirAttributeIsNull(attr)) {
2107       throw SetPyError(PyExc_KeyError,
2108                        "attempt to access a non-existent attribute");
2109     }
2110     return PyAttribute(operation->getContext(), attr);
2111   }
2112 
2113   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2114     if (index < 0 || index >= dunderLen()) {
2115       throw SetPyError(PyExc_IndexError,
2116                        "attempt to access out of bounds attribute");
2117     }
2118     MlirNamedAttribute namedAttr =
2119         mlirOperationGetAttribute(operation->get(), index);
2120     return PyNamedAttribute(
2121         namedAttr.attribute,
2122         std::string(mlirIdentifierStr(namedAttr.name).data,
2123                     mlirIdentifierStr(namedAttr.name).length));
2124   }
2125 
2126   void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2127     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2128                                     attr);
2129   }
2130 
2131   void dunderDelItem(const std::string &name) {
2132     int removed = mlirOperationRemoveAttributeByName(operation->get(),
2133                                                      toMlirStringRef(name));
2134     if (!removed)
2135       throw SetPyError(PyExc_KeyError,
2136                        "attempt to delete a non-existent attribute");
2137   }
2138 
2139   intptr_t dunderLen() {
2140     return mlirOperationGetNumAttributes(operation->get());
2141   }
2142 
2143   bool dunderContains(const std::string &name) {
2144     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2145         operation->get(), toMlirStringRef(name)));
2146   }
2147 
2148   static void bind(py::module &m) {
2149     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2150         .def("__contains__", &PyOpAttributeMap::dunderContains)
2151         .def("__len__", &PyOpAttributeMap::dunderLen)
2152         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2153         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2154         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2155         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2156   }
2157 
2158 private:
2159   PyOperationRef operation;
2160 };
2161 
2162 } // namespace
2163 
2164 //------------------------------------------------------------------------------
2165 // Populates the core exports of the 'ir' submodule.
2166 //------------------------------------------------------------------------------
2167 
2168 void mlir::python::populateIRCore(py::module &m) {
2169   //----------------------------------------------------------------------------
2170   // Enums.
2171   //----------------------------------------------------------------------------
2172   py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2173       .value("ERROR", MlirDiagnosticError)
2174       .value("WARNING", MlirDiagnosticWarning)
2175       .value("NOTE", MlirDiagnosticNote)
2176       .value("REMARK", MlirDiagnosticRemark);
2177 
2178   //----------------------------------------------------------------------------
2179   // Mapping of Diagnostics.
2180   //----------------------------------------------------------------------------
2181   py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2182       .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2183       .def_property_readonly("location", &PyDiagnostic::getLocation)
2184       .def_property_readonly("message", &PyDiagnostic::getMessage)
2185       .def_property_readonly("notes", &PyDiagnostic::getNotes)
2186       .def("__str__", [](PyDiagnostic &self) -> py::str {
2187         if (!self.isValid())
2188           return "<Invalid Diagnostic>";
2189         return self.getMessage();
2190       });
2191 
2192   py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2193       .def("detach", &PyDiagnosticHandler::detach)
2194       .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2195       .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2196       .def("__enter__", &PyDiagnosticHandler::contextEnter)
2197       .def("__exit__", &PyDiagnosticHandler::contextExit);
2198 
2199   //----------------------------------------------------------------------------
2200   // Mapping of MlirContext.
2201   //----------------------------------------------------------------------------
2202   py::class_<PyMlirContext>(m, "Context", py::module_local())
2203       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2204       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2205       .def("_get_context_again",
2206            [](PyMlirContext &self) {
2207              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2208              return ref.releaseObject();
2209            })
2210       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2211       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2212       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2213                              &PyMlirContext::getCapsule)
2214       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2215       .def("__enter__", &PyMlirContext::contextEnter)
2216       .def("__exit__", &PyMlirContext::contextExit)
2217       .def_property_readonly_static(
2218           "current",
2219           [](py::object & /*class*/) {
2220             auto *context = PyThreadContextEntry::getDefaultContext();
2221             if (!context)
2222               throw SetPyError(PyExc_ValueError, "No current Context");
2223             return context;
2224           },
2225           "Gets the Context bound to the current thread or raises ValueError")
2226       .def_property_readonly(
2227           "dialects",
2228           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2229           "Gets a container for accessing dialects by name")
2230       .def_property_readonly(
2231           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2232           "Alias for 'dialect'")
2233       .def(
2234           "get_dialect_descriptor",
2235           [=](PyMlirContext &self, std::string &name) {
2236             MlirDialect dialect = mlirContextGetOrLoadDialect(
2237                 self.get(), {name.data(), name.size()});
2238             if (mlirDialectIsNull(dialect)) {
2239               throw SetPyError(PyExc_ValueError,
2240                                Twine("Dialect '") + name + "' not found");
2241             }
2242             return PyDialectDescriptor(self.getRef(), dialect);
2243           },
2244           py::arg("dialect_name"),
2245           "Gets or loads a dialect by name, returning its descriptor object")
2246       .def_property(
2247           "allow_unregistered_dialects",
2248           [](PyMlirContext &self) -> bool {
2249             return mlirContextGetAllowUnregisteredDialects(self.get());
2250           },
2251           [](PyMlirContext &self, bool value) {
2252             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2253           })
2254       .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2255            py::arg("callback"),
2256            "Attaches a diagnostic handler that will receive callbacks")
2257       .def(
2258           "enable_multithreading",
2259           [](PyMlirContext &self, bool enable) {
2260             mlirContextEnableMultithreading(self.get(), enable);
2261           },
2262           py::arg("enable"))
2263       .def(
2264           "is_registered_operation",
2265           [](PyMlirContext &self, std::string &name) {
2266             return mlirContextIsRegisteredOperation(
2267                 self.get(), MlirStringRef{name.data(), name.size()});
2268           },
2269           py::arg("operation_name"));
2270 
2271   //----------------------------------------------------------------------------
2272   // Mapping of PyDialectDescriptor
2273   //----------------------------------------------------------------------------
2274   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2275       .def_property_readonly("namespace",
2276                              [](PyDialectDescriptor &self) {
2277                                MlirStringRef ns =
2278                                    mlirDialectGetNamespace(self.get());
2279                                return py::str(ns.data, ns.length);
2280                              })
2281       .def("__repr__", [](PyDialectDescriptor &self) {
2282         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2283         std::string repr("<DialectDescriptor ");
2284         repr.append(ns.data, ns.length);
2285         repr.append(">");
2286         return repr;
2287       });
2288 
2289   //----------------------------------------------------------------------------
2290   // Mapping of PyDialects
2291   //----------------------------------------------------------------------------
2292   py::class_<PyDialects>(m, "Dialects", py::module_local())
2293       .def("__getitem__",
2294            [=](PyDialects &self, std::string keyName) {
2295              MlirDialect dialect =
2296                  self.getDialectForKey(keyName, /*attrError=*/false);
2297              py::object descriptor =
2298                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2299              return createCustomDialectWrapper(keyName, std::move(descriptor));
2300            })
2301       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2302         MlirDialect dialect =
2303             self.getDialectForKey(attrName, /*attrError=*/true);
2304         py::object descriptor =
2305             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2306         return createCustomDialectWrapper(attrName, std::move(descriptor));
2307       });
2308 
2309   //----------------------------------------------------------------------------
2310   // Mapping of PyDialect
2311   //----------------------------------------------------------------------------
2312   py::class_<PyDialect>(m, "Dialect", py::module_local())
2313       .def(py::init<py::object>(), py::arg("descriptor"))
2314       .def_property_readonly(
2315           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2316       .def("__repr__", [](py::object self) {
2317         auto clazz = self.attr("__class__");
2318         return py::str("<Dialect ") +
2319                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2320                clazz.attr("__module__") + py::str(".") +
2321                clazz.attr("__name__") + py::str(")>");
2322       });
2323 
2324   //----------------------------------------------------------------------------
2325   // Mapping of Location
2326   //----------------------------------------------------------------------------
2327   py::class_<PyLocation>(m, "Location", py::module_local())
2328       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2329       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2330       .def("__enter__", &PyLocation::contextEnter)
2331       .def("__exit__", &PyLocation::contextExit)
2332       .def("__eq__",
2333            [](PyLocation &self, PyLocation &other) -> bool {
2334              return mlirLocationEqual(self, other);
2335            })
2336       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2337       .def_property_readonly_static(
2338           "current",
2339           [](py::object & /*class*/) {
2340             auto *loc = PyThreadContextEntry::getDefaultLocation();
2341             if (!loc)
2342               throw SetPyError(PyExc_ValueError, "No current Location");
2343             return loc;
2344           },
2345           "Gets the Location bound to the current thread or raises ValueError")
2346       .def_static(
2347           "unknown",
2348           [](DefaultingPyMlirContext context) {
2349             return PyLocation(context->getRef(),
2350                               mlirLocationUnknownGet(context->get()));
2351           },
2352           py::arg("context") = py::none(),
2353           "Gets a Location representing an unknown location")
2354       .def_static(
2355           "callsite",
2356           [](PyLocation callee, const std::vector<PyLocation> &frames,
2357              DefaultingPyMlirContext context) {
2358             if (frames.empty())
2359               throw py::value_error("No caller frames provided");
2360             MlirLocation caller = frames.back().get();
2361             for (const PyLocation &frame :
2362                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2363               caller = mlirLocationCallSiteGet(frame.get(), caller);
2364             return PyLocation(context->getRef(),
2365                               mlirLocationCallSiteGet(callee.get(), caller));
2366           },
2367           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2368           kContextGetCallSiteLocationDocstring)
2369       .def_static(
2370           "file",
2371           [](std::string filename, int line, int col,
2372              DefaultingPyMlirContext context) {
2373             return PyLocation(
2374                 context->getRef(),
2375                 mlirLocationFileLineColGet(
2376                     context->get(), toMlirStringRef(filename), line, col));
2377           },
2378           py::arg("filename"), py::arg("line"), py::arg("col"),
2379           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2380       .def_static(
2381           "fused",
2382           [](const std::vector<PyLocation> &pyLocations,
2383              llvm::Optional<PyAttribute> metadata,
2384              DefaultingPyMlirContext context) {
2385             llvm::SmallVector<MlirLocation, 4> locations;
2386             locations.reserve(pyLocations.size());
2387             for (auto &pyLocation : pyLocations)
2388               locations.push_back(pyLocation.get());
2389             MlirLocation location = mlirLocationFusedGet(
2390                 context->get(), locations.size(), locations.data(),
2391                 metadata ? metadata->get() : MlirAttribute{0});
2392             return PyLocation(context->getRef(), location);
2393           },
2394           py::arg("locations"), py::arg("metadata") = py::none(),
2395           py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2396       .def_static(
2397           "name",
2398           [](std::string name, llvm::Optional<PyLocation> childLoc,
2399              DefaultingPyMlirContext context) {
2400             return PyLocation(
2401                 context->getRef(),
2402                 mlirLocationNameGet(
2403                     context->get(), toMlirStringRef(name),
2404                     childLoc ? childLoc->get()
2405                              : mlirLocationUnknownGet(context->get())));
2406           },
2407           py::arg("name"), py::arg("childLoc") = py::none(),
2408           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2409       .def_property_readonly(
2410           "context",
2411           [](PyLocation &self) { return self.getContext().getObject(); },
2412           "Context that owns the Location")
2413       .def(
2414           "emit_error",
2415           [](PyLocation &self, std::string message) {
2416             mlirEmitError(self, message.c_str());
2417           },
2418           py::arg("message"), "Emits an error at this location")
2419       .def("__repr__", [](PyLocation &self) {
2420         PyPrintAccumulator printAccum;
2421         mlirLocationPrint(self, printAccum.getCallback(),
2422                           printAccum.getUserData());
2423         return printAccum.join();
2424       });
2425 
2426   //----------------------------------------------------------------------------
2427   // Mapping of Module
2428   //----------------------------------------------------------------------------
2429   py::class_<PyModule>(m, "Module", py::module_local())
2430       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2431       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2432       .def_static(
2433           "parse",
2434           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2435             MlirModule module = mlirModuleCreateParse(
2436                 context->get(), toMlirStringRef(moduleAsm));
2437             // TODO: Rework error reporting once diagnostic engine is exposed
2438             // in C API.
2439             if (mlirModuleIsNull(module)) {
2440               throw SetPyError(
2441                   PyExc_ValueError,
2442                   "Unable to parse module assembly (see diagnostics)");
2443             }
2444             return PyModule::forModule(module).releaseObject();
2445           },
2446           py::arg("asm"), py::arg("context") = py::none(),
2447           kModuleParseDocstring)
2448       .def_static(
2449           "create",
2450           [](DefaultingPyLocation loc) {
2451             MlirModule module = mlirModuleCreateEmpty(loc);
2452             return PyModule::forModule(module).releaseObject();
2453           },
2454           py::arg("loc") = py::none(), "Creates an empty module")
2455       .def_property_readonly(
2456           "context",
2457           [](PyModule &self) { return self.getContext().getObject(); },
2458           "Context that created the Module")
2459       .def_property_readonly(
2460           "operation",
2461           [](PyModule &self) {
2462             return PyOperation::forOperation(self.getContext(),
2463                                              mlirModuleGetOperation(self.get()),
2464                                              self.getRef().releaseObject())
2465                 .releaseObject();
2466           },
2467           "Accesses the module as an operation")
2468       .def_property_readonly(
2469           "body",
2470           [](PyModule &self) {
2471             PyOperationRef moduleOp = PyOperation::forOperation(
2472                 self.getContext(), mlirModuleGetOperation(self.get()),
2473                 self.getRef().releaseObject());
2474             PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2475             return returnBlock;
2476           },
2477           "Return the block for this module")
2478       .def(
2479           "dump",
2480           [](PyModule &self) {
2481             mlirOperationDump(mlirModuleGetOperation(self.get()));
2482           },
2483           kDumpDocstring)
2484       .def(
2485           "__str__",
2486           [](py::object self) {
2487             // Defer to the operation's __str__.
2488             return self.attr("operation").attr("__str__")();
2489           },
2490           kOperationStrDunderDocstring);
2491 
2492   //----------------------------------------------------------------------------
2493   // Mapping of Operation.
2494   //----------------------------------------------------------------------------
2495   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2496       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2497                              [](PyOperationBase &self) {
2498                                return self.getOperation().getCapsule();
2499                              })
2500       .def("__eq__",
2501            [](PyOperationBase &self, PyOperationBase &other) {
2502              return &self.getOperation() == &other.getOperation();
2503            })
2504       .def("__eq__",
2505            [](PyOperationBase &self, py::object other) { return false; })
2506       .def("__hash__",
2507            [](PyOperationBase &self) {
2508              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2509            })
2510       .def_property_readonly("attributes",
2511                              [](PyOperationBase &self) {
2512                                return PyOpAttributeMap(
2513                                    self.getOperation().getRef());
2514                              })
2515       .def_property_readonly("operands",
2516                              [](PyOperationBase &self) {
2517                                return PyOpOperandList(
2518                                    self.getOperation().getRef());
2519                              })
2520       .def_property_readonly("regions",
2521                              [](PyOperationBase &self) {
2522                                return PyRegionList(
2523                                    self.getOperation().getRef());
2524                              })
2525       .def_property_readonly(
2526           "results",
2527           [](PyOperationBase &self) {
2528             return PyOpResultList(self.getOperation().getRef());
2529           },
2530           "Returns the list of Operation results.")
2531       .def_property_readonly(
2532           "result",
2533           [](PyOperationBase &self) {
2534             auto &operation = self.getOperation();
2535             auto numResults = mlirOperationGetNumResults(operation);
2536             if (numResults != 1) {
2537               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2538               throw SetPyError(
2539                   PyExc_ValueError,
2540                   Twine("Cannot call .result on operation ") +
2541                       StringRef(name.data, name.length) + " which has " +
2542                       Twine(numResults) +
2543                       " results (it is only valid for operations with a "
2544                       "single result)");
2545             }
2546             return PyOpResult(operation.getRef(),
2547                               mlirOperationGetResult(operation, 0));
2548           },
2549           "Shortcut to get an op result if it has only one (throws an error "
2550           "otherwise).")
2551       .def_property_readonly(
2552           "location",
2553           [](PyOperationBase &self) {
2554             PyOperation &operation = self.getOperation();
2555             return PyLocation(operation.getContext(),
2556                               mlirOperationGetLocation(operation.get()));
2557           },
2558           "Returns the source location the operation was defined or derived "
2559           "from.")
2560       .def(
2561           "__str__",
2562           [](PyOperationBase &self) {
2563             return self.getAsm(/*binary=*/false,
2564                                /*largeElementsLimit=*/llvm::None,
2565                                /*enableDebugInfo=*/false,
2566                                /*prettyDebugInfo=*/false,
2567                                /*printGenericOpForm=*/false,
2568                                /*useLocalScope=*/false,
2569                                /*assumeVerified=*/false);
2570           },
2571           "Returns the assembly form of the operation.")
2572       .def("print", &PyOperationBase::print,
2573            // Careful: Lots of arguments must match up with print method.
2574            py::arg("file") = py::none(), py::arg("binary") = false,
2575            py::arg("large_elements_limit") = py::none(),
2576            py::arg("enable_debug_info") = false,
2577            py::arg("pretty_debug_info") = false,
2578            py::arg("print_generic_op_form") = false,
2579            py::arg("use_local_scope") = false,
2580            py::arg("assume_verified") = false, kOperationPrintDocstring)
2581       .def("get_asm", &PyOperationBase::getAsm,
2582            // Careful: Lots of arguments must match up with get_asm method.
2583            py::arg("binary") = false,
2584            py::arg("large_elements_limit") = py::none(),
2585            py::arg("enable_debug_info") = false,
2586            py::arg("pretty_debug_info") = false,
2587            py::arg("print_generic_op_form") = false,
2588            py::arg("use_local_scope") = false,
2589            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
2590       .def(
2591           "verify",
2592           [](PyOperationBase &self) {
2593             return mlirOperationVerify(self.getOperation());
2594           },
2595           "Verify the operation and return true if it passes, false if it "
2596           "fails.")
2597       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2598            "Puts self immediately after the other operation in its parent "
2599            "block.")
2600       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2601            "Puts self immediately before the other operation in its parent "
2602            "block.")
2603       .def(
2604           "detach_from_parent",
2605           [](PyOperationBase &self) {
2606             PyOperation &operation = self.getOperation();
2607             operation.checkValid();
2608             if (!operation.isAttached())
2609               throw py::value_error("Detached operation has no parent.");
2610 
2611             operation.detachFromParent();
2612             return operation.createOpView();
2613           },
2614           "Detaches the operation from its parent block.");
2615 
2616   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2617       .def_static("create", &PyOperation::create, py::arg("name"),
2618                   py::arg("results") = py::none(),
2619                   py::arg("operands") = py::none(),
2620                   py::arg("attributes") = py::none(),
2621                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2622                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2623                   kOperationCreateDocstring)
2624       .def_property_readonly("parent",
2625                              [](PyOperation &self) -> py::object {
2626                                auto parent = self.getParentOperation();
2627                                if (parent)
2628                                  return parent->getObject();
2629                                return py::none();
2630                              })
2631       .def("erase", &PyOperation::erase)
2632       .def("clone", &PyOperation::clone, py::arg("ip") = py::none())
2633       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2634                              &PyOperation::getCapsule)
2635       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2636       .def_property_readonly("name",
2637                              [](PyOperation &self) {
2638                                self.checkValid();
2639                                MlirOperation operation = self.get();
2640                                MlirStringRef name = mlirIdentifierStr(
2641                                    mlirOperationGetName(operation));
2642                                return py::str(name.data, name.length);
2643                              })
2644       .def_property_readonly(
2645           "context",
2646           [](PyOperation &self) {
2647             self.checkValid();
2648             return self.getContext().getObject();
2649           },
2650           "Context that owns the Operation")
2651       .def_property_readonly("opview", &PyOperation::createOpView);
2652 
2653   auto opViewClass =
2654       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2655           .def(py::init<py::object>(), py::arg("operation"))
2656           .def_property_readonly("operation", &PyOpView::getOperationObject)
2657           .def_property_readonly(
2658               "context",
2659               [](PyOpView &self) {
2660                 return self.getOperation().getContext().getObject();
2661               },
2662               "Context that owns the Operation")
2663           .def("__str__", [](PyOpView &self) {
2664             return py::str(self.getOperationObject());
2665           });
2666   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2667   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2668   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2669   opViewClass.attr("build_generic") = classmethod(
2670       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2671       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2672       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2673       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2674       "Builds a specific, generated OpView based on class level attributes.");
2675 
2676   //----------------------------------------------------------------------------
2677   // Mapping of PyRegion.
2678   //----------------------------------------------------------------------------
2679   py::class_<PyRegion>(m, "Region", py::module_local())
2680       .def_property_readonly(
2681           "blocks",
2682           [](PyRegion &self) {
2683             return PyBlockList(self.getParentOperation(), self.get());
2684           },
2685           "Returns a forward-optimized sequence of blocks.")
2686       .def_property_readonly(
2687           "owner",
2688           [](PyRegion &self) {
2689             return self.getParentOperation()->createOpView();
2690           },
2691           "Returns the operation owning this region.")
2692       .def(
2693           "__iter__",
2694           [](PyRegion &self) {
2695             self.checkValid();
2696             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2697             return PyBlockIterator(self.getParentOperation(), firstBlock);
2698           },
2699           "Iterates over blocks in the region.")
2700       .def("__eq__",
2701            [](PyRegion &self, PyRegion &other) {
2702              return self.get().ptr == other.get().ptr;
2703            })
2704       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2705 
2706   //----------------------------------------------------------------------------
2707   // Mapping of PyBlock.
2708   //----------------------------------------------------------------------------
2709   py::class_<PyBlock>(m, "Block", py::module_local())
2710       .def_property_readonly(
2711           "owner",
2712           [](PyBlock &self) {
2713             return self.getParentOperation()->createOpView();
2714           },
2715           "Returns the owning operation of this block.")
2716       .def_property_readonly(
2717           "region",
2718           [](PyBlock &self) {
2719             MlirRegion region = mlirBlockGetParentRegion(self.get());
2720             return PyRegion(self.getParentOperation(), region);
2721           },
2722           "Returns the owning region of this block.")
2723       .def_property_readonly(
2724           "arguments",
2725           [](PyBlock &self) {
2726             return PyBlockArgumentList(self.getParentOperation(), self.get());
2727           },
2728           "Returns a list of block arguments.")
2729       .def_property_readonly(
2730           "operations",
2731           [](PyBlock &self) {
2732             return PyOperationList(self.getParentOperation(), self.get());
2733           },
2734           "Returns a forward-optimized sequence of operations.")
2735       .def_static(
2736           "create_at_start",
2737           [](PyRegion &parent, py::list pyArgTypes) {
2738             parent.checkValid();
2739             llvm::SmallVector<MlirType, 4> argTypes;
2740             llvm::SmallVector<MlirLocation, 4> argLocs;
2741             argTypes.reserve(pyArgTypes.size());
2742             argLocs.reserve(pyArgTypes.size());
2743             for (auto &pyArg : pyArgTypes) {
2744               argTypes.push_back(pyArg.cast<PyType &>());
2745               // TODO: Pass in a proper location here.
2746               argLocs.push_back(
2747                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2748             }
2749 
2750             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2751                                               argLocs.data());
2752             mlirRegionInsertOwnedBlock(parent, 0, block);
2753             return PyBlock(parent.getParentOperation(), block);
2754           },
2755           py::arg("parent"), py::arg("arg_types") = py::list(),
2756           "Creates and returns a new Block at the beginning of the given "
2757           "region (with given argument types).")
2758       .def(
2759           "append_to",
2760           [](PyBlock &self, PyRegion &region) {
2761             MlirBlock b = self.get();
2762             if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
2763               mlirBlockDetach(b);
2764             mlirRegionAppendOwnedBlock(region.get(), b);
2765           },
2766           "Append this block to a region, transferring ownership if necessary")
2767       .def(
2768           "create_before",
2769           [](PyBlock &self, py::args pyArgTypes) {
2770             self.checkValid();
2771             llvm::SmallVector<MlirType, 4> argTypes;
2772             llvm::SmallVector<MlirLocation, 4> argLocs;
2773             argTypes.reserve(pyArgTypes.size());
2774             argLocs.reserve(pyArgTypes.size());
2775             for (auto &pyArg : pyArgTypes) {
2776               argTypes.push_back(pyArg.cast<PyType &>());
2777               // TODO: Pass in a proper location here.
2778               argLocs.push_back(
2779                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2780             }
2781 
2782             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2783                                               argLocs.data());
2784             MlirRegion region = mlirBlockGetParentRegion(self.get());
2785             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2786             return PyBlock(self.getParentOperation(), block);
2787           },
2788           "Creates and returns a new Block before this block "
2789           "(with given argument types).")
2790       .def(
2791           "create_after",
2792           [](PyBlock &self, py::args pyArgTypes) {
2793             self.checkValid();
2794             llvm::SmallVector<MlirType, 4> argTypes;
2795             llvm::SmallVector<MlirLocation, 4> argLocs;
2796             argTypes.reserve(pyArgTypes.size());
2797             argLocs.reserve(pyArgTypes.size());
2798             for (auto &pyArg : pyArgTypes) {
2799               argTypes.push_back(pyArg.cast<PyType &>());
2800 
2801               // TODO: Pass in a proper location here.
2802               argLocs.push_back(
2803                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2804             }
2805             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2806                                               argLocs.data());
2807             MlirRegion region = mlirBlockGetParentRegion(self.get());
2808             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2809             return PyBlock(self.getParentOperation(), block);
2810           },
2811           "Creates and returns a new Block after this block "
2812           "(with given argument types).")
2813       .def(
2814           "__iter__",
2815           [](PyBlock &self) {
2816             self.checkValid();
2817             MlirOperation firstOperation =
2818                 mlirBlockGetFirstOperation(self.get());
2819             return PyOperationIterator(self.getParentOperation(),
2820                                        firstOperation);
2821           },
2822           "Iterates over operations in the block.")
2823       .def("__eq__",
2824            [](PyBlock &self, PyBlock &other) {
2825              return self.get().ptr == other.get().ptr;
2826            })
2827       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2828       .def(
2829           "__str__",
2830           [](PyBlock &self) {
2831             self.checkValid();
2832             PyPrintAccumulator printAccum;
2833             mlirBlockPrint(self.get(), printAccum.getCallback(),
2834                            printAccum.getUserData());
2835             return printAccum.join();
2836           },
2837           "Returns the assembly form of the block.")
2838       .def(
2839           "append",
2840           [](PyBlock &self, PyOperationBase &operation) {
2841             if (operation.getOperation().isAttached())
2842               operation.getOperation().detachFromParent();
2843 
2844             MlirOperation mlirOperation = operation.getOperation().get();
2845             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2846             operation.getOperation().setAttached(
2847                 self.getParentOperation().getObject());
2848           },
2849           py::arg("operation"),
2850           "Appends an operation to this block. If the operation is currently "
2851           "in another block, it will be moved.");
2852 
2853   //----------------------------------------------------------------------------
2854   // Mapping of PyInsertionPoint.
2855   //----------------------------------------------------------------------------
2856 
2857   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2858       .def(py::init<PyBlock &>(), py::arg("block"),
2859            "Inserts after the last operation but still inside the block.")
2860       .def("__enter__", &PyInsertionPoint::contextEnter)
2861       .def("__exit__", &PyInsertionPoint::contextExit)
2862       .def_property_readonly_static(
2863           "current",
2864           [](py::object & /*class*/) {
2865             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2866             if (!ip)
2867               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2868             return ip;
2869           },
2870           "Gets the InsertionPoint bound to the current thread or raises "
2871           "ValueError if none has been set")
2872       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2873            "Inserts before a referenced operation.")
2874       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2875                   py::arg("block"), "Inserts at the beginning of the block.")
2876       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2877                   py::arg("block"), "Inserts before the block terminator.")
2878       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2879            "Inserts an operation.")
2880       .def_property_readonly(
2881           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2882           "Returns the block that this InsertionPoint points to.");
2883 
2884   //----------------------------------------------------------------------------
2885   // Mapping of PyAttribute.
2886   //----------------------------------------------------------------------------
2887   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2888       // Delegate to the PyAttribute copy constructor, which will also lifetime
2889       // extend the backing context which owns the MlirAttribute.
2890       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2891            "Casts the passed attribute to the generic Attribute")
2892       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2893                              &PyAttribute::getCapsule)
2894       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2895       .def_static(
2896           "parse",
2897           [](std::string attrSpec, DefaultingPyMlirContext context) {
2898             MlirAttribute type = mlirAttributeParseGet(
2899                 context->get(), toMlirStringRef(attrSpec));
2900             // TODO: Rework error reporting once diagnostic engine is exposed
2901             // in C API.
2902             if (mlirAttributeIsNull(type)) {
2903               throw SetPyError(PyExc_ValueError,
2904                                Twine("Unable to parse attribute: '") +
2905                                    attrSpec + "'");
2906             }
2907             return PyAttribute(context->getRef(), type);
2908           },
2909           py::arg("asm"), py::arg("context") = py::none(),
2910           "Parses an attribute from an assembly form")
2911       .def_property_readonly(
2912           "context",
2913           [](PyAttribute &self) { return self.getContext().getObject(); },
2914           "Context that owns the Attribute")
2915       .def_property_readonly("type",
2916                              [](PyAttribute &self) {
2917                                return PyType(self.getContext()->getRef(),
2918                                              mlirAttributeGetType(self));
2919                              })
2920       .def(
2921           "get_named",
2922           [](PyAttribute &self, std::string name) {
2923             return PyNamedAttribute(self, std::move(name));
2924           },
2925           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2926       .def("__eq__",
2927            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2928       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2929       .def("__hash__",
2930            [](PyAttribute &self) {
2931              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2932            })
2933       .def(
2934           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2935           kDumpDocstring)
2936       .def(
2937           "__str__",
2938           [](PyAttribute &self) {
2939             PyPrintAccumulator printAccum;
2940             mlirAttributePrint(self, printAccum.getCallback(),
2941                                printAccum.getUserData());
2942             return printAccum.join();
2943           },
2944           "Returns the assembly form of the Attribute.")
2945       .def("__repr__", [](PyAttribute &self) {
2946         // Generally, assembly formats are not printed for __repr__ because
2947         // this can cause exceptionally long debug output and exceptions.
2948         // However, attribute values are generally considered useful and are
2949         // printed. This may need to be re-evaluated if debug dumps end up
2950         // being excessive.
2951         PyPrintAccumulator printAccum;
2952         printAccum.parts.append("Attribute(");
2953         mlirAttributePrint(self, printAccum.getCallback(),
2954                            printAccum.getUserData());
2955         printAccum.parts.append(")");
2956         return printAccum.join();
2957       });
2958 
2959   //----------------------------------------------------------------------------
2960   // Mapping of PyNamedAttribute
2961   //----------------------------------------------------------------------------
2962   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2963       .def("__repr__",
2964            [](PyNamedAttribute &self) {
2965              PyPrintAccumulator printAccum;
2966              printAccum.parts.append("NamedAttribute(");
2967              printAccum.parts.append(
2968                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
2969                          mlirIdentifierStr(self.namedAttr.name).length));
2970              printAccum.parts.append("=");
2971              mlirAttributePrint(self.namedAttr.attribute,
2972                                 printAccum.getCallback(),
2973                                 printAccum.getUserData());
2974              printAccum.parts.append(")");
2975              return printAccum.join();
2976            })
2977       .def_property_readonly(
2978           "name",
2979           [](PyNamedAttribute &self) {
2980             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2981                            mlirIdentifierStr(self.namedAttr.name).length);
2982           },
2983           "The name of the NamedAttribute binding")
2984       .def_property_readonly(
2985           "attr",
2986           [](PyNamedAttribute &self) {
2987             // TODO: When named attribute is removed/refactored, also remove
2988             // this constructor (it does an inefficient table lookup).
2989             auto contextRef = PyMlirContext::forContext(
2990                 mlirAttributeGetContext(self.namedAttr.attribute));
2991             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2992           },
2993           py::keep_alive<0, 1>(),
2994           "The underlying generic attribute of the NamedAttribute binding");
2995 
2996   //----------------------------------------------------------------------------
2997   // Mapping of PyType.
2998   //----------------------------------------------------------------------------
2999   py::class_<PyType>(m, "Type", py::module_local())
3000       // Delegate to the PyType copy constructor, which will also lifetime
3001       // extend the backing context which owns the MlirType.
3002       .def(py::init<PyType &>(), py::arg("cast_from_type"),
3003            "Casts the passed type to the generic Type")
3004       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3005       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3006       .def_static(
3007           "parse",
3008           [](std::string typeSpec, DefaultingPyMlirContext context) {
3009             MlirType type =
3010                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3011             // TODO: Rework error reporting once diagnostic engine is exposed
3012             // in C API.
3013             if (mlirTypeIsNull(type)) {
3014               throw SetPyError(PyExc_ValueError,
3015                                Twine("Unable to parse type: '") + typeSpec +
3016                                    "'");
3017             }
3018             return PyType(context->getRef(), type);
3019           },
3020           py::arg("asm"), py::arg("context") = py::none(),
3021           kContextParseTypeDocstring)
3022       .def_property_readonly(
3023           "context", [](PyType &self) { return self.getContext().getObject(); },
3024           "Context that owns the Type")
3025       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3026       .def("__eq__", [](PyType &self, py::object &other) { return false; })
3027       .def("__hash__",
3028            [](PyType &self) {
3029              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3030            })
3031       .def(
3032           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3033       .def(
3034           "__str__",
3035           [](PyType &self) {
3036             PyPrintAccumulator printAccum;
3037             mlirTypePrint(self, printAccum.getCallback(),
3038                           printAccum.getUserData());
3039             return printAccum.join();
3040           },
3041           "Returns the assembly form of the type.")
3042       .def("__repr__", [](PyType &self) {
3043         // Generally, assembly formats are not printed for __repr__ because
3044         // this can cause exceptionally long debug output and exceptions.
3045         // However, types are an exception as they typically have compact
3046         // assembly forms and printing them is useful.
3047         PyPrintAccumulator printAccum;
3048         printAccum.parts.append("Type(");
3049         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
3050         printAccum.parts.append(")");
3051         return printAccum.join();
3052       });
3053 
3054   //----------------------------------------------------------------------------
3055   // Mapping of Value.
3056   //----------------------------------------------------------------------------
3057   py::class_<PyValue>(m, "Value", py::module_local())
3058       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3059       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3060       .def_property_readonly(
3061           "context",
3062           [](PyValue &self) { return self.getParentOperation()->getContext(); },
3063           "Context in which the value lives.")
3064       .def(
3065           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3066           kDumpDocstring)
3067       .def_property_readonly(
3068           "owner",
3069           [](PyValue &self) {
3070             assert(mlirOperationEqual(self.getParentOperation()->get(),
3071                                       mlirOpResultGetOwner(self.get())) &&
3072                    "expected the owner of the value in Python to match that in "
3073                    "the IR");
3074             return self.getParentOperation().getObject();
3075           })
3076       .def("__eq__",
3077            [](PyValue &self, PyValue &other) {
3078              return self.get().ptr == other.get().ptr;
3079            })
3080       .def("__eq__", [](PyValue &self, py::object other) { return false; })
3081       .def("__hash__",
3082            [](PyValue &self) {
3083              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3084            })
3085       .def(
3086           "__str__",
3087           [](PyValue &self) {
3088             PyPrintAccumulator printAccum;
3089             printAccum.parts.append("Value(");
3090             mlirValuePrint(self.get(), printAccum.getCallback(),
3091                            printAccum.getUserData());
3092             printAccum.parts.append(")");
3093             return printAccum.join();
3094           },
3095           kValueDunderStrDocstring)
3096       .def_property_readonly("type", [](PyValue &self) {
3097         return PyType(self.getParentOperation()->getContext(),
3098                       mlirValueGetType(self.get()));
3099       });
  // Register the specialized Value wrappers for block arguments and op
  // results. NOTE(review): these presumably bind subclasses of the Value class
  // registered immediately above, so they must stay after it -- confirm in
  // PyBlockArgument::bind / PyOpResult::bind.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);
3102 
3103   //----------------------------------------------------------------------------
3104   // Mapping of SymbolTable.
3105   //----------------------------------------------------------------------------
3106   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3107       .def(py::init<PyOperationBase &>())
3108       .def("__getitem__", &PySymbolTable::dunderGetItem)
3109       .def("insert", &PySymbolTable::insert, py::arg("operation"))
3110       .def("erase", &PySymbolTable::erase, py::arg("operation"))
3111       .def("__delitem__", &PySymbolTable::dunderDel)
3112       .def("__contains__",
3113            [](PySymbolTable &table, const std::string &name) {
3114              return !mlirOperationIsNull(mlirSymbolTableLookup(
3115                  table, mlirStringRefCreate(name.data(), name.length())));
3116            })
3117       // Static helpers.
3118       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3119                   py::arg("symbol"), py::arg("name"))
3120       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3121                   py::arg("symbol"))
3122       .def_static("get_visibility", &PySymbolTable::getVisibility,
3123                   py::arg("symbol"))
3124       .def_static("set_visibility", &PySymbolTable::setVisibility,
3125                   py::arg("symbol"), py::arg("visibility"))
3126       .def_static("replace_all_symbol_uses",
3127                   &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3128                   py::arg("new_symbol"), py::arg("from_op"))
3129       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3130                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
3131                   py::arg("callback"));
3132 
  // Container bindings: the list/iterator/map helper classes exposed to
  // Python. Each bind() registers its own class on `m`; the order is kept
  // as-is because pybind11 renders signatures using whatever types are
  // already registered at definition time.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings (see PyGlobalDebugFlag::bind for what is exposed).
  PyGlobalDebugFlag::bind(m);
3147 }
3148