1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 //#include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 
23 #include <utility>
24 
25 namespace py = pybind11;
26 using namespace mlir;
27 using namespace mlir::python;
28 
29 using llvm::SmallVector;
30 using llvm::StringRef;
31 using llvm::Twine;
32 
33 //------------------------------------------------------------------------------
34 // Docstrings (trivial, non-duplicated docstrings are included inline).
35 //------------------------------------------------------------------------------
36 
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

// NOTE: the keyword names listed below are user-facing documentation and
// should match the snake_case py::arg names used where print()/get_asm() are
// bound (elsewhere in this file).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
147 
148 //------------------------------------------------------------------------------
149 // Utilities.
150 //------------------------------------------------------------------------------
151 
152 /// Helper for creating an @classmethod.
153 template <class Func, typename... Args>
154 py::object classmethod(Func f, Args... args) {
155   py::object cf = py::cpp_function(f, args...);
156   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
157 }
158 
159 static py::object
160 createCustomDialectWrapper(const std::string &dialectNamespace,
161                            py::object dialectDescriptor) {
162   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
163   if (!dialectClass) {
164     // Use the base class.
165     return py::cast(PyDialect(std::move(dialectDescriptor)));
166   }
167 
168   // Create the custom implementation.
169   return (*dialectClass)(std::move(dialectDescriptor));
170 }
171 
172 static MlirStringRef toMlirStringRef(const std::string &s) {
173   return mlirStringRefCreate(s.data(), s.size());
174 }
175 
176 /// Wrapper for the global LLVM debugging flag.
177 struct PyGlobalDebugFlag {
178   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
179 
180   static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
181 
182   static void bind(py::module &m) {
183     // Debug flags.
184     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
185         .def_property_static("flag", &PyGlobalDebugFlag::get,
186                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
187   }
188 };
189 
190 //------------------------------------------------------------------------------
191 // Collections.
192 //------------------------------------------------------------------------------
193 
194 namespace {
195 
196 class PyRegionIterator {
197 public:
198   PyRegionIterator(PyOperationRef operation)
199       : operation(std::move(operation)) {}
200 
201   PyRegionIterator &dunderIter() { return *this; }
202 
203   PyRegion dunderNext() {
204     operation->checkValid();
205     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
206       throw py::stop_iteration();
207     }
208     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
209     return PyRegion(operation, region);
210   }
211 
212   static void bind(py::module &m) {
213     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
214         .def("__iter__", &PyRegionIterator::dunderIter)
215         .def("__next__", &PyRegionIterator::dunderNext);
216   }
217 
218 private:
219   PyOperationRef operation;
220   int nextIndex = 0;
221 };
222 
223 /// Regions of an op are fixed length and indexed numerically so are represented
224 /// with a sequence-like container.
225 class PyRegionList {
226 public:
227   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
228 
229   intptr_t dunderLen() {
230     operation->checkValid();
231     return mlirOperationGetNumRegions(operation->get());
232   }
233 
234   PyRegion dunderGetItem(intptr_t index) {
235     // dunderLen checks validity.
236     if (index < 0 || index >= dunderLen()) {
237       throw SetPyError(PyExc_IndexError,
238                        "attempt to access out of bounds region");
239     }
240     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
241     return PyRegion(operation, region);
242   }
243 
244   static void bind(py::module &m) {
245     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
246         .def("__len__", &PyRegionList::dunderLen)
247         .def("__getitem__", &PyRegionList::dunderGetItem);
248   }
249 
250 private:
251   PyOperationRef operation;
252 };
253 
254 class PyBlockIterator {
255 public:
256   PyBlockIterator(PyOperationRef operation, MlirBlock next)
257       : operation(std::move(operation)), next(next) {}
258 
259   PyBlockIterator &dunderIter() { return *this; }
260 
261   PyBlock dunderNext() {
262     operation->checkValid();
263     if (mlirBlockIsNull(next)) {
264       throw py::stop_iteration();
265     }
266 
267     PyBlock returnBlock(operation, next);
268     next = mlirBlockGetNextInRegion(next);
269     return returnBlock;
270   }
271 
272   static void bind(py::module &m) {
273     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
274         .def("__iter__", &PyBlockIterator::dunderIter)
275         .def("__next__", &PyBlockIterator::dunderNext);
276   }
277 
278 private:
279   PyOperationRef operation;
280   MlirBlock next;
281 };
282 
283 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
284 /// we present them as a more full-featured list-like container but optimize
285 /// it for forward iteration. Blocks are always owned by a region.
286 class PyBlockList {
287 public:
288   PyBlockList(PyOperationRef operation, MlirRegion region)
289       : operation(std::move(operation)), region(region) {}
290 
291   PyBlockIterator dunderIter() {
292     operation->checkValid();
293     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
294   }
295 
296   intptr_t dunderLen() {
297     operation->checkValid();
298     intptr_t count = 0;
299     MlirBlock block = mlirRegionGetFirstBlock(region);
300     while (!mlirBlockIsNull(block)) {
301       count += 1;
302       block = mlirBlockGetNextInRegion(block);
303     }
304     return count;
305   }
306 
307   PyBlock dunderGetItem(intptr_t index) {
308     operation->checkValid();
309     if (index < 0) {
310       throw SetPyError(PyExc_IndexError,
311                        "attempt to access out of bounds block");
312     }
313     MlirBlock block = mlirRegionGetFirstBlock(region);
314     while (!mlirBlockIsNull(block)) {
315       if (index == 0) {
316         return PyBlock(operation, block);
317       }
318       block = mlirBlockGetNextInRegion(block);
319       index -= 1;
320     }
321     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
322   }
323 
324   PyBlock appendBlock(const py::args &pyArgTypes) {
325     operation->checkValid();
326     llvm::SmallVector<MlirType, 4> argTypes;
327     llvm::SmallVector<MlirLocation, 4> argLocs;
328     argTypes.reserve(pyArgTypes.size());
329     argLocs.reserve(pyArgTypes.size());
330     for (auto &pyArg : pyArgTypes) {
331       argTypes.push_back(pyArg.cast<PyType &>());
332       // TODO: Pass in a proper location here.
333       argLocs.push_back(
334           mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
335     }
336 
337     MlirBlock block =
338         mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
339     mlirRegionAppendOwnedBlock(region, block);
340     return PyBlock(operation, block);
341   }
342 
343   static void bind(py::module &m) {
344     py::class_<PyBlockList>(m, "BlockList", py::module_local())
345         .def("__getitem__", &PyBlockList::dunderGetItem)
346         .def("__iter__", &PyBlockList::dunderIter)
347         .def("__len__", &PyBlockList::dunderLen)
348         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
349   }
350 
351 private:
352   PyOperationRef operation;
353   MlirRegion region;
354 };
355 
356 class PyOperationIterator {
357 public:
358   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
359       : parentOperation(std::move(parentOperation)), next(next) {}
360 
361   PyOperationIterator &dunderIter() { return *this; }
362 
363   py::object dunderNext() {
364     parentOperation->checkValid();
365     if (mlirOperationIsNull(next)) {
366       throw py::stop_iteration();
367     }
368 
369     PyOperationRef returnOperation =
370         PyOperation::forOperation(parentOperation->getContext(), next);
371     next = mlirOperationGetNextInBlock(next);
372     return returnOperation->createOpView();
373   }
374 
375   static void bind(py::module &m) {
376     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
377         .def("__iter__", &PyOperationIterator::dunderIter)
378         .def("__next__", &PyOperationIterator::dunderNext);
379   }
380 
381 private:
382   PyOperationRef parentOperation;
383   MlirOperation next;
384 };
385 
386 /// Operations are exposed by the C-API as a forward-only linked list. In
387 /// Python, we present them as a more full-featured list-like container but
388 /// optimize it for forward iteration. Iterable operations are always owned
389 /// by a block.
390 class PyOperationList {
391 public:
392   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
393       : parentOperation(std::move(parentOperation)), block(block) {}
394 
395   PyOperationIterator dunderIter() {
396     parentOperation->checkValid();
397     return PyOperationIterator(parentOperation,
398                                mlirBlockGetFirstOperation(block));
399   }
400 
401   intptr_t dunderLen() {
402     parentOperation->checkValid();
403     intptr_t count = 0;
404     MlirOperation childOp = mlirBlockGetFirstOperation(block);
405     while (!mlirOperationIsNull(childOp)) {
406       count += 1;
407       childOp = mlirOperationGetNextInBlock(childOp);
408     }
409     return count;
410   }
411 
412   py::object dunderGetItem(intptr_t index) {
413     parentOperation->checkValid();
414     if (index < 0) {
415       throw SetPyError(PyExc_IndexError,
416                        "attempt to access out of bounds operation");
417     }
418     MlirOperation childOp = mlirBlockGetFirstOperation(block);
419     while (!mlirOperationIsNull(childOp)) {
420       if (index == 0) {
421         return PyOperation::forOperation(parentOperation->getContext(), childOp)
422             ->createOpView();
423       }
424       childOp = mlirOperationGetNextInBlock(childOp);
425       index -= 1;
426     }
427     throw SetPyError(PyExc_IndexError,
428                      "attempt to access out of bounds operation");
429   }
430 
431   static void bind(py::module &m) {
432     py::class_<PyOperationList>(m, "OperationList", py::module_local())
433         .def("__getitem__", &PyOperationList::dunderGetItem)
434         .def("__iter__", &PyOperationList::dunderIter)
435         .def("__len__", &PyOperationList::dunderLen);
436   }
437 
438 private:
439   PyOperationRef parentOperation;
440   MlirBlock block;
441 };
442 
443 } // namespace
444 
445 //------------------------------------------------------------------------------
446 // PyMlirContext
447 //------------------------------------------------------------------------------
448 
449 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
450   py::gil_scoped_acquire acquire;
451   auto &liveContexts = getLiveContexts();
452   liveContexts[context.ptr] = this;
453 }
454 
455 PyMlirContext::~PyMlirContext() {
456   // Note that the only public way to construct an instance is via the
457   // forContext method, which always puts the associated handle into
458   // liveContexts.
459   py::gil_scoped_acquire acquire;
460   getLiveContexts().erase(context.ptr);
461   mlirContextDestroy(context);
462 }
463 
464 py::object PyMlirContext::getCapsule() {
465   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
466 }
467 
468 py::object PyMlirContext::createFromCapsule(py::object capsule) {
469   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
470   if (mlirContextIsNull(rawContext))
471     throw py::error_already_set();
472   return forContext(rawContext).releaseObject();
473 }
474 
475 PyMlirContext *PyMlirContext::createNewContextForInit() {
476   MlirContext context = mlirContextCreate();
477   return new PyMlirContext(context);
478 }
479 
480 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
481   py::gil_scoped_acquire acquire;
482   auto &liveContexts = getLiveContexts();
483   auto it = liveContexts.find(context.ptr);
484   if (it == liveContexts.end()) {
485     // Create.
486     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
487     py::object pyRef = py::cast(unownedContextWrapper);
488     assert(pyRef && "cast to py::object failed");
489     liveContexts[context.ptr] = unownedContextWrapper;
490     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
491   }
492   // Use existing.
493   py::object pyRef = py::cast(it->second);
494   return PyMlirContextRef(it->second, std::move(pyRef));
495 }
496 
497 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
498   static LiveContextMap liveContexts;
499   return liveContexts;
500 }
501 
502 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
503 
504 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
505 
506 size_t PyMlirContext::clearLiveOperations() {
507   for (auto &op : liveOperations)
508     op.second.second->setInvalid();
509   size_t numInvalidated = liveOperations.size();
510   liveOperations.clear();
511   return numInvalidated;
512 }
513 
514 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
515 
516 pybind11::object PyMlirContext::contextEnter() {
517   return PyThreadContextEntry::pushContext(*this);
518 }
519 
520 void PyMlirContext::contextExit(const pybind11::object &excType,
521                                 const pybind11::object &excVal,
522                                 const pybind11::object &excTb) {
523   PyThreadContextEntry::popContext(*this);
524 }
525 
526 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
527   // Note that ownership is transferred to the delete callback below by way of
528   // an explicit inc_ref (borrow).
529   PyDiagnosticHandler *pyHandler =
530       new PyDiagnosticHandler(get(), std::move(callback));
531   py::object pyHandlerObject =
532       py::cast(pyHandler, py::return_value_policy::take_ownership);
533   pyHandlerObject.inc_ref();
534 
535   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
536   // guaranteed to be known to pybind.
537   auto handlerCallback =
538       +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
539     PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
540     py::object pyDiagnosticObject =
541         py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
542 
543     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
544     bool result = false;
545     {
546       // Since this can be called from arbitrary C++ contexts, always get the
547       // gil.
548       py::gil_scoped_acquire gil;
549       try {
550         result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
551       } catch (std::exception &e) {
552         fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
553                 e.what());
554         pyHandler->hadError = true;
555       }
556     }
557 
558     pyDiagnostic->invalidate();
559     return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
560   };
561   auto deleteCallback = +[](void *userData) {
562     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
563     assert(pyHandler->registeredID && "handler is not registered");
564     pyHandler->registeredID.reset();
565 
566     // Decrement reference, balancing the inc_ref() above.
567     py::object pyHandlerObject =
568         py::cast(pyHandler, py::return_value_policy::reference);
569     pyHandlerObject.dec_ref();
570   };
571 
572   pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
573       get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
574   return pyHandlerObject;
575 }
576 
577 PyMlirContext &DefaultingPyMlirContext::resolve() {
578   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
579   if (!context) {
580     throw SetPyError(
581         PyExc_RuntimeError,
582         "An MLIR function requires a Context but none was provided in the call "
583         "or from the surrounding environment. Either pass to the function with "
584         "a 'context=' argument or establish a default using 'with Context():'");
585   }
586   return *context;
587 }
588 
589 //------------------------------------------------------------------------------
590 // PyThreadContextEntry management
591 //------------------------------------------------------------------------------
592 
593 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
594   static thread_local std::vector<PyThreadContextEntry> stack;
595   return stack;
596 }
597 
598 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
599   auto &stack = getStack();
600   if (stack.empty())
601     return nullptr;
602   return &stack.back();
603 }
604 
605 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
606                                 py::object insertionPoint,
607                                 py::object location) {
608   auto &stack = getStack();
609   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
610                      std::move(location));
611   // If the new stack has more than one entry and the context of the new top
612   // entry matches the previous, copy the insertionPoint and location from the
613   // previous entry if missing from the new top entry.
614   if (stack.size() > 1) {
615     auto &prev = *(stack.rbegin() + 1);
616     auto &current = stack.back();
617     if (current.context.is(prev.context)) {
618       // Default non-context objects from the previous entry.
619       if (!current.insertionPoint)
620         current.insertionPoint = prev.insertionPoint;
621       if (!current.location)
622         current.location = prev.location;
623     }
624   }
625 }
626 
627 PyMlirContext *PyThreadContextEntry::getContext() {
628   if (!context)
629     return nullptr;
630   return py::cast<PyMlirContext *>(context);
631 }
632 
633 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
634   if (!insertionPoint)
635     return nullptr;
636   return py::cast<PyInsertionPoint *>(insertionPoint);
637 }
638 
639 PyLocation *PyThreadContextEntry::getLocation() {
640   if (!location)
641     return nullptr;
642   return py::cast<PyLocation *>(location);
643 }
644 
645 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
646   auto *tos = getTopOfStack();
647   return tos ? tos->getContext() : nullptr;
648 }
649 
650 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
651   auto *tos = getTopOfStack();
652   return tos ? tos->getInsertionPoint() : nullptr;
653 }
654 
655 PyLocation *PyThreadContextEntry::getDefaultLocation() {
656   auto *tos = getTopOfStack();
657   return tos ? tos->getLocation() : nullptr;
658 }
659 
660 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
661   py::object contextObj = py::cast(context);
662   push(FrameKind::Context, /*context=*/contextObj,
663        /*insertionPoint=*/py::object(),
664        /*location=*/py::object());
665   return contextObj;
666 }
667 
668 void PyThreadContextEntry::popContext(PyMlirContext &context) {
669   auto &stack = getStack();
670   if (stack.empty())
671     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
672   auto &tos = stack.back();
673   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
674     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
675   stack.pop_back();
676 }
677 
678 py::object
679 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
680   py::object contextObj =
681       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
682   py::object insertionPointObj = py::cast(insertionPoint);
683   push(FrameKind::InsertionPoint,
684        /*context=*/contextObj,
685        /*insertionPoint=*/insertionPointObj,
686        /*location=*/py::object());
687   return insertionPointObj;
688 }
689 
690 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
691   auto &stack = getStack();
692   if (stack.empty())
693     throw SetPyError(PyExc_RuntimeError,
694                      "Unbalanced InsertionPoint enter/exit");
695   auto &tos = stack.back();
696   if (tos.frameKind != FrameKind::InsertionPoint &&
697       tos.getInsertionPoint() != &insertionPoint)
698     throw SetPyError(PyExc_RuntimeError,
699                      "Unbalanced InsertionPoint enter/exit");
700   stack.pop_back();
701 }
702 
703 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
704   py::object contextObj = location.getContext().getObject();
705   py::object locationObj = py::cast(location);
706   push(FrameKind::Location, /*context=*/contextObj,
707        /*insertionPoint=*/py::object(),
708        /*location=*/locationObj);
709   return locationObj;
710 }
711 
712 void PyThreadContextEntry::popLocation(PyLocation &location) {
713   auto &stack = getStack();
714   if (stack.empty())
715     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
716   auto &tos = stack.back();
717   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
718     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
719   stack.pop_back();
720 }
721 
722 //------------------------------------------------------------------------------
723 // PyDiagnostic*
724 //------------------------------------------------------------------------------
725 
726 void PyDiagnostic::invalidate() {
727   valid = false;
728   if (materializedNotes) {
729     for (auto &noteObject : *materializedNotes) {
730       PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
731       note->invalidate();
732     }
733   }
734 }
735 
736 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
737                                          py::object callback)
738     : context(context), callback(std::move(callback)) {}
739 
740 PyDiagnosticHandler::~PyDiagnosticHandler() = default;
741 
/// Detaches this handler from its context. Does nothing if the handler is not
/// currently registered.
///
/// Detaching runs the delete callback installed by attachDiagnosticHandler().
/// That callback clears `registeredID` (asserted below) and drops the extra
/// Python reference taken at attach time.
/// NOTE(review): this relies on the caller holding its own Python reference
/// to the handler. Otherwise that dec_ref could free `this` before `context`
/// is reset below.
void PyDiagnosticHandler::detach() {
  if (!registeredID)
    return;
  // Take a local copy: `registeredID` is reset by the delete callback while
  // the detach call below is still running.
  MlirDiagnosticHandlerID localID = *registeredID;
  mlirContextDetachDiagnosticHandler(context, localID);
  assert(!registeredID && "should have unregistered");
  // Not strictly necessary but keeps stale pointers from being around to cause
  // issues.
  context = {nullptr};
}
752 
753 void PyDiagnostic::checkValid() {
754   if (!valid) {
755     throw std::invalid_argument(
756         "Diagnostic is invalid (used outside of callback)");
757   }
758 }
759 
760 MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
761   checkValid();
762   return mlirDiagnosticGetSeverity(diagnostic);
763 }
764 
765 PyLocation PyDiagnostic::getLocation() {
766   checkValid();
767   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
768   MlirContext context = mlirLocationGetContext(loc);
769   return PyLocation(PyMlirContext::forContext(context), loc);
770 }
771 
772 py::str PyDiagnostic::getMessage() {
773   checkValid();
774   py::object fileObject = py::module::import("io").attr("StringIO")();
775   PyFileAccumulator accum(fileObject, /*binary=*/false);
776   mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
777   return fileObject.attr("getvalue")();
778 }
779 
780 py::tuple PyDiagnostic::getNotes() {
781   checkValid();
782   if (materializedNotes)
783     return *materializedNotes;
784   intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
785   materializedNotes = py::tuple(numNotes);
786   for (intptr_t i = 0; i < numNotes; ++i) {
787     MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
788     materializedNotes.value()[i] = PyDiagnostic(noteDiag);
789   }
790   return *materializedNotes;
791 }
792 
793 //------------------------------------------------------------------------------
794 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
795 //------------------------------------------------------------------------------
796 
797 MlirDialect PyDialects::getDialectForKey(const std::string &key,
798                                          bool attrError) {
799   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
800                                                     {key.data(), key.size()});
801   if (mlirDialectIsNull(dialect)) {
802     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
803                      Twine("Dialect '") + key + "' not found");
804   }
805   return dialect;
806 }
807 
808 py::object PyDialectRegistry::getCapsule() {
809   return py::reinterpret_steal<py::object>(
810       mlirPythonDialectRegistryToCapsule(*this));
811 }
812 
813 PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) {
814   MlirDialectRegistry rawRegistry =
815       mlirPythonCapsuleToDialectRegistry(capsule.ptr());
816   if (mlirDialectRegistryIsNull(rawRegistry))
817     throw py::error_already_set();
818   return PyDialectRegistry(rawRegistry);
819 }
820 
821 //------------------------------------------------------------------------------
822 // PyLocation
823 //------------------------------------------------------------------------------
824 
825 py::object PyLocation::getCapsule() {
826   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
827 }
828 
829 PyLocation PyLocation::createFromCapsule(py::object capsule) {
830   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
831   if (mlirLocationIsNull(rawLoc))
832     throw py::error_already_set();
833   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
834                     rawLoc);
835 }
836 
837 py::object PyLocation::contextEnter() {
838   return PyThreadContextEntry::pushLocation(*this);
839 }
840 
841 void PyLocation::contextExit(const pybind11::object &excType,
842                              const pybind11::object &excVal,
843                              const pybind11::object &excTb) {
844   PyThreadContextEntry::popLocation(*this);
845 }
846 
847 PyLocation &DefaultingPyLocation::resolve() {
848   auto *location = PyThreadContextEntry::getDefaultLocation();
849   if (!location) {
850     throw SetPyError(
851         PyExc_RuntimeError,
852         "An MLIR function requires a Location but none was provided in the "
853         "call or from the surrounding environment. Either pass to the function "
854         "with a 'loc=' argument or establish a default using 'with loc:'");
855   }
856   return *location;
857 }
858 
859 //------------------------------------------------------------------------------
860 // PyModule
861 //------------------------------------------------------------------------------
862 
863 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
864     : BaseContextObject(std::move(contextRef)), module(module) {}
865 
866 PyModule::~PyModule() {
867   py::gil_scoped_acquire acquire;
868   auto &liveModules = getContext()->liveModules;
869   assert(liveModules.count(module.ptr) == 1 &&
870          "destroying module not in live map");
871   liveModules.erase(module.ptr);
872   mlirModuleDestroy(module);
873 }
874 
875 PyModuleRef PyModule::forModule(MlirModule module) {
876   MlirContext context = mlirModuleGetContext(module);
877   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
878 
879   py::gil_scoped_acquire acquire;
880   auto &liveModules = contextRef->liveModules;
881   auto it = liveModules.find(module.ptr);
882   if (it == liveModules.end()) {
883     // Create.
884     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
885     // Note that the default return value policy on cast is automatic_reference,
886     // which does not take ownership (delete will not be called).
887     // Just be explicit.
888     py::object pyRef =
889         py::cast(unownedModule, py::return_value_policy::take_ownership);
890     unownedModule->handle = pyRef;
891     liveModules[module.ptr] =
892         std::make_pair(unownedModule->handle, unownedModule);
893     return PyModuleRef(unownedModule, std::move(pyRef));
894   }
895   // Use existing.
896   PyModule *existing = it->second.second;
897   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
898   return PyModuleRef(existing, std::move(pyRef));
899 }
900 
901 py::object PyModule::createFromCapsule(py::object capsule) {
902   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
903   if (mlirModuleIsNull(rawModule))
904     throw py::error_already_set();
905   return forModule(rawModule).releaseObject();
906 }
907 
908 py::object PyModule::getCapsule() {
909   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
910 }
911 
912 //------------------------------------------------------------------------------
913 // PyOperation
914 //------------------------------------------------------------------------------
915 
916 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
917     : BaseContextObject(std::move(contextRef)), operation(operation) {}
918 
919 PyOperation::~PyOperation() {
920   // If the operation has already been invalidated there is nothing to do.
921   if (!valid)
922     return;
923   auto &liveOperations = getContext()->liveOperations;
924   assert(liveOperations.count(operation.ptr) == 1 &&
925          "destroying operation not in live map");
926   liveOperations.erase(operation.ptr);
927   if (!isAttached()) {
928     mlirOperationDestroy(operation);
929   }
930 }
931 
932 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
933                                            MlirOperation operation,
934                                            py::object parentKeepAlive) {
935   auto &liveOperations = contextRef->liveOperations;
936   // Create.
937   PyOperation *unownedOperation =
938       new PyOperation(std::move(contextRef), operation);
939   // Note that the default return value policy on cast is automatic_reference,
940   // which does not take ownership (delete will not be called).
941   // Just be explicit.
942   py::object pyRef =
943       py::cast(unownedOperation, py::return_value_policy::take_ownership);
944   unownedOperation->handle = pyRef;
945   if (parentKeepAlive) {
946     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
947   }
948   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
949   return PyOperationRef(unownedOperation, std::move(pyRef));
950 }
951 
952 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
953                                          MlirOperation operation,
954                                          py::object parentKeepAlive) {
955   auto &liveOperations = contextRef->liveOperations;
956   auto it = liveOperations.find(operation.ptr);
957   if (it == liveOperations.end()) {
958     // Create.
959     return createInstance(std::move(contextRef), operation,
960                           std::move(parentKeepAlive));
961   }
962   // Use existing.
963   PyOperation *existing = it->second.second;
964   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
965   return PyOperationRef(existing, std::move(pyRef));
966 }
967 
968 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
969                                            MlirOperation operation,
970                                            py::object parentKeepAlive) {
971   auto &liveOperations = contextRef->liveOperations;
972   assert(liveOperations.count(operation.ptr) == 0 &&
973          "cannot create detached operation that already exists");
974   (void)liveOperations;
975 
976   PyOperationRef created = createInstance(std::move(contextRef), operation,
977                                           std::move(parentKeepAlive));
978   created->attached = false;
979   return created;
980 }
981 
982 void PyOperation::checkValid() const {
983   if (!valid) {
984     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
985   }
986 }
987 
988 void PyOperationBase::print(py::object fileObject, bool binary,
989                             llvm::Optional<int64_t> largeElementsLimit,
990                             bool enableDebugInfo, bool prettyDebugInfo,
991                             bool printGenericOpForm, bool useLocalScope,
992                             bool assumeVerified) {
993   PyOperation &operation = getOperation();
994   operation.checkValid();
995   if (fileObject.is_none())
996     fileObject = py::module::import("sys").attr("stdout");
997 
998   if (!assumeVerified && !printGenericOpForm &&
999       !mlirOperationVerify(operation)) {
1000     std::string message("// Verification failed, printing generic form\n");
1001     if (binary) {
1002       fileObject.attr("write")(py::bytes(message));
1003     } else {
1004       fileObject.attr("write")(py::str(message));
1005     }
1006     printGenericOpForm = true;
1007   }
1008 
1009   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1010   if (largeElementsLimit)
1011     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1012   if (enableDebugInfo)
1013     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
1014   if (printGenericOpForm)
1015     mlirOpPrintingFlagsPrintGenericOpForm(flags);
1016   if (useLocalScope)
1017     mlirOpPrintingFlagsUseLocalScope(flags);
1018 
1019   PyFileAccumulator accum(fileObject, binary);
1020   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1021                               accum.getUserData());
1022   mlirOpPrintingFlagsDestroy(flags);
1023 }
1024 
1025 py::object PyOperationBase::getAsm(bool binary,
1026                                    llvm::Optional<int64_t> largeElementsLimit,
1027                                    bool enableDebugInfo, bool prettyDebugInfo,
1028                                    bool printGenericOpForm, bool useLocalScope,
1029                                    bool assumeVerified) {
1030   py::object fileObject;
1031   if (binary) {
1032     fileObject = py::module::import("io").attr("BytesIO")();
1033   } else {
1034     fileObject = py::module::import("io").attr("StringIO")();
1035   }
1036   print(fileObject, /*binary=*/binary,
1037         /*largeElementsLimit=*/largeElementsLimit,
1038         /*enableDebugInfo=*/enableDebugInfo,
1039         /*prettyDebugInfo=*/prettyDebugInfo,
1040         /*printGenericOpForm=*/printGenericOpForm,
1041         /*useLocalScope=*/useLocalScope,
1042         /*assumeVerified=*/assumeVerified);
1043 
1044   return fileObject.attr("getvalue")();
1045 }
1046 
1047 void PyOperationBase::moveAfter(PyOperationBase &other) {
1048   PyOperation &operation = getOperation();
1049   PyOperation &otherOp = other.getOperation();
1050   operation.checkValid();
1051   otherOp.checkValid();
1052   mlirOperationMoveAfter(operation, otherOp);
1053   operation.parentKeepAlive = otherOp.parentKeepAlive;
1054 }
1055 
1056 void PyOperationBase::moveBefore(PyOperationBase &other) {
1057   PyOperation &operation = getOperation();
1058   PyOperation &otherOp = other.getOperation();
1059   operation.checkValid();
1060   otherOp.checkValid();
1061   mlirOperationMoveBefore(operation, otherOp);
1062   operation.parentKeepAlive = otherOp.parentKeepAlive;
1063 }
1064 
1065 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
1066   checkValid();
1067   if (!isAttached())
1068     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
1069   MlirOperation operation = mlirOperationGetParentOperation(get());
1070   if (mlirOperationIsNull(operation))
1071     return {};
1072   return PyOperation::forOperation(getContext(), operation);
1073 }
1074 
1075 PyBlock PyOperation::getBlock() {
1076   checkValid();
1077   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
1078   MlirBlock block = mlirOperationGetBlock(get());
1079   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1080   assert(parentOperation && "Operation has no parent");
1081   return PyBlock{std::move(*parentOperation), block};
1082 }
1083 
1084 py::object PyOperation::getCapsule() {
1085   checkValid();
1086   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1087 }
1088 
1089 py::object PyOperation::createFromCapsule(py::object capsule) {
1090   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1091   if (mlirOperationIsNull(rawOperation))
1092     throw py::error_already_set();
1093   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1094   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1095       .releaseObject();
1096 }
1097 
1098 static void maybeInsertOperation(PyOperationRef &op,
1099                                  const py::object &maybeIp) {
1100   // InsertPoint active?
1101   if (!maybeIp.is(py::cast(false))) {
1102     PyInsertionPoint *ip;
1103     if (maybeIp.is_none()) {
1104       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1105     } else {
1106       ip = py::cast<PyInsertionPoint *>(maybeIp);
1107     }
1108     if (ip)
1109       ip->insert(*op.get());
1110   }
1111 }
1112 
// Creates a new operation named `name` from already-unpacked components and
// returns the view produced by createOpView() for it.
//
// - `results`, `operands`, `successors`: optional lists; a None entry raises
//   ValueError.
// - `attributes`: optional dict; a non-string key or a non-Attribute value
//   raises py::cast_error whose message names the operation.
// - `regions`: number of fresh regions to create; must be >= 0 (ValueError).
// - `location`: the op location; its context is the one the new operation is
//   registered in.
// - `maybeIp`: forwarded to maybeInsertOperation (False = do not insert,
//   None = thread-default insertion point, otherwise an InsertionPoint).
//
// All validation that can throw is performed before the MlirOperationState is
// created, so no C-API state is leaked on error.
py::object PyOperation::create(
    const std::string &name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, const py::object &maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Keys are owned std::strings so that the MlirStringRefs built from them
  // later stay valid until the operation state is consumed.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Each region is created empty and handed to the state via
    // mlirOperationStateAddOwnedRegions.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  // It starts out detached; maybeInsertOperation may then attach it.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created->createOpView();
}
1234 
1235 py::object PyOperation::clone(const py::object &maybeIp) {
1236   MlirOperation clonedOperation = mlirOperationClone(operation);
1237   PyOperationRef cloned =
1238       PyOperation::createDetached(getContext(), clonedOperation);
1239   maybeInsertOperation(cloned, maybeIp);
1240 
1241   return cloned->createOpView();
1242 }
1243 
1244 py::object PyOperation::createOpView() {
1245   checkValid();
1246   MlirIdentifier ident = mlirOperationGetName(get());
1247   MlirStringRef identStr = mlirIdentifierStr(ident);
1248   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1249       StringRef(identStr.data, identStr.length));
1250   if (opViewClass)
1251     return (*opViewClass)(getRef().getObject());
1252   return py::cast(PyOpView(getRef().getObject()));
1253 }
1254 
1255 void PyOperation::erase() {
1256   checkValid();
1257   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1258   // Python reference to a child operation is live. All children should also
1259   // have their `valid` bit set to false.
1260   auto &liveOperations = getContext()->liveOperations;
1261   if (liveOperations.count(operation.ptr))
1262     liveOperations.erase(operation.ptr);
1263   mlirOperationDestroy(operation);
1264   valid = false;
1265 }
1266 
1267 //------------------------------------------------------------------------------
1268 // PyOpView
1269 //------------------------------------------------------------------------------
1270 
1271 py::object PyOpView::buildGeneric(
1272     const py::object &cls, py::list resultTypeList, py::list operandList,
1273     llvm::Optional<py::dict> attributes,
1274     llvm::Optional<std::vector<PyBlock *>> successors,
1275     llvm::Optional<int> regions, DefaultingPyLocation location,
1276     const py::object &maybeIp) {
1277   PyMlirContextRef context = location->getContext();
1278   // Class level operation construction metadata.
1279   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1280   // Operand and result segment specs are either none, which does no
1281   // variadic unpacking, or a list of ints with segment sizes, where each
1282   // element is either a positive number (typically 1 for a scalar) or -1 to
1283   // indicate that it is derived from the length of the same-indexed operand
1284   // or result (implying that it is a list at that position).
1285   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1286   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1287 
1288   std::vector<uint32_t> operandSegmentLengths;
1289   std::vector<uint32_t> resultSegmentLengths;
1290 
1291   // Validate/determine region count.
1292   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1293   int opMinRegionCount = std::get<0>(opRegionSpec);
1294   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1295   if (!regions) {
1296     regions = opMinRegionCount;
1297   }
1298   if (*regions < opMinRegionCount) {
1299     throw py::value_error(
1300         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1301          llvm::Twine(opMinRegionCount) +
1302          " regions but was built with regions=" + llvm::Twine(*regions))
1303             .str());
1304   }
1305   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1306     throw py::value_error(
1307         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1308          llvm::Twine(opMinRegionCount) +
1309          " regions but was built with regions=" + llvm::Twine(*regions))
1310             .str());
1311   }
1312 
1313   // Unpack results.
1314   std::vector<PyType *> resultTypes;
1315   resultTypes.reserve(resultTypeList.size());
1316   if (resultSegmentSpecObj.is_none()) {
1317     // Non-variadic result unpacking.
1318     for (const auto &it : llvm::enumerate(resultTypeList)) {
1319       try {
1320         resultTypes.push_back(py::cast<PyType *>(it.value()));
1321         if (!resultTypes.back())
1322           throw py::cast_error();
1323       } catch (py::cast_error &err) {
1324         throw py::value_error((llvm::Twine("Result ") +
1325                                llvm::Twine(it.index()) + " of operation \"" +
1326                                name + "\" must be a Type (" + err.what() + ")")
1327                                   .str());
1328       }
1329     }
1330   } else {
1331     // Sized result unpacking.
1332     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1333     if (resultSegmentSpec.size() != resultTypeList.size()) {
1334       throw py::value_error((llvm::Twine("Operation \"") + name +
1335                              "\" requires " +
1336                              llvm::Twine(resultSegmentSpec.size()) +
1337                              " result segments but was provided " +
1338                              llvm::Twine(resultTypeList.size()))
1339                                 .str());
1340     }
1341     resultSegmentLengths.reserve(resultTypeList.size());
1342     for (const auto &it :
1343          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1344       int segmentSpec = std::get<1>(it.value());
1345       if (segmentSpec == 1 || segmentSpec == 0) {
1346         // Unpack unary element.
1347         try {
1348           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1349           if (resultType) {
1350             resultTypes.push_back(resultType);
1351             resultSegmentLengths.push_back(1);
1352           } else if (segmentSpec == 0) {
1353             // Allowed to be optional.
1354             resultSegmentLengths.push_back(0);
1355           } else {
1356             throw py::cast_error("was None and result is not optional");
1357           }
1358         } catch (py::cast_error &err) {
1359           throw py::value_error((llvm::Twine("Result ") +
1360                                  llvm::Twine(it.index()) + " of operation \"" +
1361                                  name + "\" must be a Type (" + err.what() +
1362                                  ")")
1363                                     .str());
1364         }
1365       } else if (segmentSpec == -1) {
1366         // Unpack sequence by appending.
1367         try {
1368           if (std::get<0>(it.value()).is_none()) {
1369             // Treat it as an empty list.
1370             resultSegmentLengths.push_back(0);
1371           } else {
1372             // Unpack the list.
1373             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1374             for (py::object segmentItem : segment) {
1375               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1376               if (!resultTypes.back()) {
1377                 throw py::cast_error("contained a None item");
1378               }
1379             }
1380             resultSegmentLengths.push_back(segment.size());
1381           }
1382         } catch (std::exception &err) {
1383           // NOTE: Sloppy to be using a catch-all here, but there are at least
1384           // three different unrelated exceptions that can be thrown in the
1385           // above "casts". Just keep the scope above small and catch them all.
1386           throw py::value_error((llvm::Twine("Result ") +
1387                                  llvm::Twine(it.index()) + " of operation \"" +
1388                                  name + "\" must be a Sequence of Types (" +
1389                                  err.what() + ")")
1390                                     .str());
1391         }
1392       } else {
1393         throw py::value_error("Unexpected segment spec");
1394       }
1395     }
1396   }
1397 
1398   // Unpack operands.
1399   std::vector<PyValue *> operands;
1400   operands.reserve(operands.size());
1401   if (operandSegmentSpecObj.is_none()) {
1402     // Non-sized operand unpacking.
1403     for (const auto &it : llvm::enumerate(operandList)) {
1404       try {
1405         operands.push_back(py::cast<PyValue *>(it.value()));
1406         if (!operands.back())
1407           throw py::cast_error();
1408       } catch (py::cast_error &err) {
1409         throw py::value_error((llvm::Twine("Operand ") +
1410                                llvm::Twine(it.index()) + " of operation \"" +
1411                                name + "\" must be a Value (" + err.what() + ")")
1412                                   .str());
1413       }
1414     }
1415   } else {
1416     // Sized operand unpacking.
1417     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1418     if (operandSegmentSpec.size() != operandList.size()) {
1419       throw py::value_error((llvm::Twine("Operation \"") + name +
1420                              "\" requires " +
1421                              llvm::Twine(operandSegmentSpec.size()) +
1422                              "operand segments but was provided " +
1423                              llvm::Twine(operandList.size()))
1424                                 .str());
1425     }
1426     operandSegmentLengths.reserve(operandList.size());
1427     for (const auto &it :
1428          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1429       int segmentSpec = std::get<1>(it.value());
1430       if (segmentSpec == 1 || segmentSpec == 0) {
1431         // Unpack unary element.
1432         try {
1433           auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1434           if (operandValue) {
1435             operands.push_back(operandValue);
1436             operandSegmentLengths.push_back(1);
1437           } else if (segmentSpec == 0) {
1438             // Allowed to be optional.
1439             operandSegmentLengths.push_back(0);
1440           } else {
1441             throw py::cast_error("was None and operand is not optional");
1442           }
1443         } catch (py::cast_error &err) {
1444           throw py::value_error((llvm::Twine("Operand ") +
1445                                  llvm::Twine(it.index()) + " of operation \"" +
1446                                  name + "\" must be a Value (" + err.what() +
1447                                  ")")
1448                                     .str());
1449         }
1450       } else if (segmentSpec == -1) {
1451         // Unpack sequence by appending.
1452         try {
1453           if (std::get<0>(it.value()).is_none()) {
1454             // Treat it as an empty list.
1455             operandSegmentLengths.push_back(0);
1456           } else {
1457             // Unpack the list.
1458             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1459             for (py::object segmentItem : segment) {
1460               operands.push_back(py::cast<PyValue *>(segmentItem));
1461               if (!operands.back()) {
1462                 throw py::cast_error("contained a None item");
1463               }
1464             }
1465             operandSegmentLengths.push_back(segment.size());
1466           }
1467         } catch (std::exception &err) {
1468           // NOTE: Sloppy to be using a catch-all here, but there are at least
1469           // three different unrelated exceptions that can be thrown in the
1470           // above "casts". Just keep the scope above small and catch them all.
1471           throw py::value_error((llvm::Twine("Operand ") +
1472                                  llvm::Twine(it.index()) + " of operation \"" +
1473                                  name + "\" must be a Sequence of Values (" +
1474                                  err.what() + ")")
1475                                     .str());
1476         }
1477       } else {
1478         throw py::value_error("Unexpected segment spec");
1479       }
1480     }
1481   }
1482 
1483   // Merge operand/result segment lengths into attributes if needed.
1484   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1485     // Dup.
1486     if (attributes) {
1487       attributes = py::dict(*attributes);
1488     } else {
1489       attributes = py::dict();
1490     }
1491     if (attributes->contains("result_segment_sizes") ||
1492         attributes->contains("operand_segment_sizes")) {
1493       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1494                             "'operand_segment_sizes' attribute is unsupported. "
1495                             "Use Operation.create for such low-level access.");
1496     }
1497 
1498     // Add result_segment_sizes attribute.
1499     if (!resultSegmentLengths.empty()) {
1500       int64_t size = resultSegmentLengths.size();
1501       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1502           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1503           resultSegmentLengths.size(), resultSegmentLengths.data());
1504       (*attributes)["result_segment_sizes"] =
1505           PyAttribute(context, segmentLengthAttr);
1506     }
1507 
1508     // Add operand_segment_sizes attribute.
1509     if (!operandSegmentLengths.empty()) {
1510       int64_t size = operandSegmentLengths.size();
1511       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1512           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1513           operandSegmentLengths.size(), operandSegmentLengths.data());
1514       (*attributes)["operand_segment_sizes"] =
1515           PyAttribute(context, segmentLengthAttr);
1516     }
1517   }
1518 
1519   // Delegate to create.
1520   return PyOperation::create(name,
1521                              /*results=*/std::move(resultTypes),
1522                              /*operands=*/std::move(operands),
1523                              /*attributes=*/std::move(attributes),
1524                              /*successors=*/std::move(successors),
1525                              /*regions=*/*regions, location, maybeIp);
1526 }
1527 
1528 PyOpView::PyOpView(const py::object &operationObject)
1529     // Casting through the PyOperationBase base-class and then back to the
1530     // Operation lets us accept any PyOperationBase subclass.
1531     : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1532       operationObject(operation.getRef().getObject()) {}
1533 
1534 py::object PyOpView::createRawSubclass(const py::object &userClass) {
1535   // This is... a little gross. The typical pattern is to have a pure python
1536   // class that extends OpView like:
1537   //   class AddFOp(_cext.ir.OpView):
1538   //     def __init__(self, loc, lhs, rhs):
1539   //       operation = loc.context.create_operation(
1540   //           "addf", lhs, rhs, results=[lhs.type])
1541   //       super().__init__(operation)
1542   //
1543   // I.e. The goal of the user facing type is to provide a nice constructor
1544   // that has complete freedom for the op under construction. This is at odds
1545   // with our other desire to sometimes create this object by just passing an
1546   // operation (to initialize the base class). We could do *arg and **kwargs
1547   // munging to try to make it work, but instead, we synthesize a new class
1548   // on the fly which extends this user class (AddFOp in this example) and
1549   // *give it* the base class's __init__ method, thus bypassing the
1550   // intermediate subclass's __init__ method entirely. While slightly,
1551   // underhanded, this is safe/legal because the type hierarchy has not changed
1552   // (we just added a new leaf) and we aren't mucking around with __new__.
1553   // Typically, this new class will be stored on the original as "_Raw" and will
1554   // be used for casts and other things that need a variant of the class that
1555   // is initialized purely from an operation.
1556   py::object parentMetaclass =
1557       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1558   py::dict attributes;
1559   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1560   // now.
1561   //   auto opViewType = py::type::of<PyOpView>();
1562   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1563   attributes["__init__"] = opViewType.attr("__init__");
1564   py::str origName = userClass.attr("__name__");
1565   py::str newName = py::str("_") + origName;
1566   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1567 }
1568 
1569 //------------------------------------------------------------------------------
1570 // PyInsertionPoint.
1571 //------------------------------------------------------------------------------
1572 
/// Insertion point at the end of `block`. No reference operation is set, so
/// insert() appends to the block.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}

/// Insertion point immediately before `beforeOperationBase`, in the block
/// that contains it. NOTE: the `block` initializer dereferences
/// `refOperation`, so this relies on `refOperation` being declared before
/// `block` in the class.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1578 
1579 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1580   PyOperation &operation = operationBase.getOperation();
1581   if (operation.isAttached())
1582     throw SetPyError(PyExc_ValueError,
1583                      "Attempt to insert operation that is already attached");
1584   block.getParentOperation()->checkValid();
1585   MlirOperation beforeOp = {nullptr};
1586   if (refOperation) {
1587     // Insert before operation.
1588     (*refOperation)->checkValid();
1589     beforeOp = (*refOperation)->get();
1590   } else {
1591     // Insert at end (before null) is only valid if the block does not
1592     // already end in a known terminator (violating this will cause assertion
1593     // failures later).
1594     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1595       throw py::index_error("Cannot insert operation at the end of a block "
1596                             "that already has a terminator. Did you mean to "
1597                             "use 'InsertionPoint.at_block_terminator(block)' "
1598                             "versus 'InsertionPoint(block)'?");
1599     }
1600   }
1601   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1602   operation.setAttached();
1603 }
1604 
1605 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1606   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1607   if (mlirOperationIsNull(firstOp)) {
1608     // Just insert at end.
1609     return PyInsertionPoint(block);
1610   }
1611 
1612   // Insert before first op.
1613   PyOperationRef firstOpRef = PyOperation::forOperation(
1614       block.getParentOperation()->getContext(), firstOp);
1615   return PyInsertionPoint{block, std::move(firstOpRef)};
1616 }
1617 
1618 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1619   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1620   if (mlirOperationIsNull(terminator))
1621     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1622   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1623       block.getParentOperation()->getContext(), terminator);
1624   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1625 }
1626 
1627 py::object PyInsertionPoint::contextEnter() {
1628   return PyThreadContextEntry::pushInsertionPoint(*this);
1629 }
1630 
1631 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1632                                    const pybind11::object &excVal,
1633                                    const pybind11::object &excTb) {
1634   PyThreadContextEntry::popInsertionPoint(*this);
1635 }
1636 
1637 //------------------------------------------------------------------------------
1638 // PyAttribute.
1639 //------------------------------------------------------------------------------
1640 
1641 bool PyAttribute::operator==(const PyAttribute &other) {
1642   return mlirAttributeEqual(attr, other.attr);
1643 }
1644 
1645 py::object PyAttribute::getCapsule() {
1646   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1647 }
1648 
1649 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1650   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1651   if (mlirAttributeIsNull(rawAttr))
1652     throw py::error_already_set();
1653   return PyAttribute(
1654       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1655 }
1656 
1657 //------------------------------------------------------------------------------
1658 // PyNamedAttribute.
1659 //------------------------------------------------------------------------------
1660 
1661 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1662     : ownedName(new std::string(std::move(ownedName))) {
1663   namedAttr = mlirNamedAttributeGet(
1664       mlirIdentifierGet(mlirAttributeGetContext(attr),
1665                         toMlirStringRef(*this->ownedName)),
1666       attr);
1667 }
1668 
1669 //------------------------------------------------------------------------------
1670 // PyType.
1671 //------------------------------------------------------------------------------
1672 
1673 bool PyType::operator==(const PyType &other) {
1674   return mlirTypeEqual(type, other.type);
1675 }
1676 
1677 py::object PyType::getCapsule() {
1678   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1679 }
1680 
1681 PyType PyType::createFromCapsule(py::object capsule) {
1682   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1683   if (mlirTypeIsNull(rawType))
1684     throw py::error_already_set();
1685   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1686                 rawType);
1687 }
1688 
1689 //------------------------------------------------------------------------------
1690 // PyValue and subclases.
1691 //------------------------------------------------------------------------------
1692 
1693 pybind11::object PyValue::getCapsule() {
1694   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1695 }
1696 
1697 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1698   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1699   if (mlirValueIsNull(value))
1700     throw py::error_already_set();
1701   MlirOperation owner;
1702   if (mlirValueIsAOpResult(value))
1703     owner = mlirOpResultGetOwner(value);
1704   if (mlirValueIsABlockArgument(value))
1705     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1706   if (mlirOperationIsNull(owner))
1707     throw py::error_already_set();
1708   MlirContext ctx = mlirOperationGetContext(owner);
1709   PyOperationRef ownerRef =
1710       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1711   return PyValue(ownerRef, value);
1712 }
1713 
1714 //------------------------------------------------------------------------------
1715 // PySymbolTable.
1716 //------------------------------------------------------------------------------
1717 
1718 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1719     : operation(operation.getOperation().getRef()) {
1720   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1721   if (mlirSymbolTableIsNull(symbolTable)) {
1722     throw py::cast_error("Operation is not a Symbol Table.");
1723   }
1724 }
1725 
1726 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1727   operation->checkValid();
1728   MlirOperation symbol = mlirSymbolTableLookup(
1729       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1730   if (mlirOperationIsNull(symbol))
1731     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1732 
1733   return PyOperation::forOperation(operation->getContext(), symbol,
1734                                    operation.getObject())
1735       ->createOpView();
1736 }
1737 
1738 void PySymbolTable::erase(PyOperationBase &symbol) {
1739   operation->checkValid();
1740   symbol.getOperation().checkValid();
1741   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1742   // The operation is also erased, so we must invalidate it. There may be Python
1743   // references to this operation so we don't want to delete it from the list of
1744   // live operations here.
1745   symbol.getOperation().valid = false;
1746 }
1747 
1748 void PySymbolTable::dunderDel(const std::string &name) {
1749   py::object operation = dunderGetItem(name);
1750   erase(py::cast<PyOperationBase &>(operation));
1751 }
1752 
1753 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1754   operation->checkValid();
1755   symbol.getOperation().checkValid();
1756   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1757       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1758   if (mlirAttributeIsNull(symbolAttr))
1759     throw py::value_error("Expected operation to have a symbol name.");
1760   return PyAttribute(
1761       symbol.getOperation().getContext(),
1762       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1763 }
1764 
1765 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1766   // Op must already be a symbol.
1767   PyOperation &operation = symbol.getOperation();
1768   operation.checkValid();
1769   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1770   MlirAttribute existingNameAttr =
1771       mlirOperationGetAttributeByName(operation.get(), attrName);
1772   if (mlirAttributeIsNull(existingNameAttr))
1773     throw py::value_error("Expected operation to have a symbol name.");
1774   return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1775 }
1776 
1777 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1778                                   const std::string &name) {
1779   // Op must already be a symbol.
1780   PyOperation &operation = symbol.getOperation();
1781   operation.checkValid();
1782   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1783   MlirAttribute existingNameAttr =
1784       mlirOperationGetAttributeByName(operation.get(), attrName);
1785   if (mlirAttributeIsNull(existingNameAttr))
1786     throw py::value_error("Expected operation to have a symbol name.");
1787   MlirAttribute newNameAttr =
1788       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1789   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1790 }
1791 
1792 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1793   PyOperation &operation = symbol.getOperation();
1794   operation.checkValid();
1795   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1796   MlirAttribute existingVisAttr =
1797       mlirOperationGetAttributeByName(operation.get(), attrName);
1798   if (mlirAttributeIsNull(existingVisAttr))
1799     throw py::value_error("Expected operation to have a symbol visibility.");
1800   return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1801 }
1802 
1803 void PySymbolTable::setVisibility(PyOperationBase &symbol,
1804                                   const std::string &visibility) {
1805   if (visibility != "public" && visibility != "private" &&
1806       visibility != "nested")
1807     throw py::value_error(
1808         "Expected visibility to be 'public', 'private' or 'nested'");
1809   PyOperation &operation = symbol.getOperation();
1810   operation.checkValid();
1811   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1812   MlirAttribute existingVisAttr =
1813       mlirOperationGetAttributeByName(operation.get(), attrName);
1814   if (mlirAttributeIsNull(existingVisAttr))
1815     throw py::value_error("Expected operation to have a symbol visibility.");
1816   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1817                                                toMlirStringRef(visibility));
1818   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1819 }
1820 
1821 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1822                                          const std::string &newSymbol,
1823                                          PyOperationBase &from) {
1824   PyOperation &fromOperation = from.getOperation();
1825   fromOperation.checkValid();
1826   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1827           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1828           from.getOperation())))
1829 
1830     throw py::value_error("Symbol rename failed");
1831 }
1832 
1833 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1834                                      bool allSymUsesVisible,
1835                                      py::object callback) {
1836   PyOperation &fromOperation = from.getOperation();
1837   fromOperation.checkValid();
1838   struct UserData {
1839     PyMlirContextRef context;
1840     py::object callback;
1841     bool gotException;
1842     std::string exceptionWhat;
1843     py::object exceptionType;
1844   };
1845   UserData userData{
1846       fromOperation.getContext(), std::move(callback), false, {}, {}};
1847   mlirSymbolTableWalkSymbolTables(
1848       fromOperation.get(), allSymUsesVisible,
1849       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1850         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1851         auto pyFoundOp =
1852             PyOperation::forOperation(calleeUserData->context, foundOp);
1853         if (calleeUserData->gotException)
1854           return;
1855         try {
1856           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1857         } catch (py::error_already_set &e) {
1858           calleeUserData->gotException = true;
1859           calleeUserData->exceptionWhat = e.what();
1860           calleeUserData->exceptionType = e.type();
1861         }
1862       },
1863       static_cast<void *>(&userData));
1864   if (userData.gotException) {
1865     std::string message("Exception raised in callback: ");
1866     message.append(userData.exceptionWhat);
1867     throw std::runtime_error(message);
1868   }
1869 }
1870 
1871 namespace {
1872 /// CRTP base class for Python MLIR values that subclass Value and should be
1873 /// castable from it. The value hierarchy is one level deep and is not supposed
1874 /// to accommodate other levels unless core MLIR changes.
1875 template <typename DerivedTy>
1876 class PyConcreteValue : public PyValue {
1877 public:
1878   // Derived classes must define statics for:
1879   //   IsAFunctionTy isaFunction
1880   //   const char *pyClassName
1881   // and redefine bindDerived.
1882   using ClassTy = py::class_<DerivedTy, PyValue>;
1883   using IsAFunctionTy = bool (*)(MlirValue);
1884 
1885   PyConcreteValue() = default;
1886   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1887       : PyValue(operationRef, value) {}
1888   PyConcreteValue(PyValue &orig)
1889       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1890 
1891   /// Attempts to cast the original value to the derived type and throws on
1892   /// type mismatches.
1893   static MlirValue castFrom(PyValue &orig) {
1894     if (!DerivedTy::isaFunction(orig.get())) {
1895       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1896       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1897                                              DerivedTy::pyClassName +
1898                                              " (from " + origRepr + ")");
1899     }
1900     return orig.get();
1901   }
1902 
1903   /// Binds the Python module objects to functions of this class.
1904   static void bind(py::module &m) {
1905     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1906     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1907     cls.def_static(
1908         "isinstance",
1909         [](PyValue &otherValue) -> bool {
1910           return DerivedTy::isaFunction(otherValue);
1911         },
1912         py::arg("other_value"));
1913     DerivedTy::bindDerived(cls);
1914   }
1915 
1916   /// Implemented by derived classes to add methods to the Python subclass.
1917   static void bindDerived(ClassTy &m) {}
1918 };
1919 
1920 /// Python wrapper for MlirBlockArgument.
1921 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1922 public:
1923   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1924   static constexpr const char *pyClassName = "BlockArgument";
1925   using PyConcreteValue::PyConcreteValue;
1926 
1927   static void bindDerived(ClassTy &c) {
1928     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1929       return PyBlock(self.getParentOperation(),
1930                      mlirBlockArgumentGetOwner(self.get()));
1931     });
1932     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1933       return mlirBlockArgumentGetArgNumber(self.get());
1934     });
1935     c.def(
1936         "set_type",
1937         [](PyBlockArgument &self, PyType type) {
1938           return mlirBlockArgumentSetType(self.get(), type);
1939         },
1940         py::arg("type"));
1941   }
1942 };
1943 
1944 /// Python wrapper for MlirOpResult.
1945 class PyOpResult : public PyConcreteValue<PyOpResult> {
1946 public:
1947   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1948   static constexpr const char *pyClassName = "OpResult";
1949   using PyConcreteValue::PyConcreteValue;
1950 
1951   static void bindDerived(ClassTy &c) {
1952     c.def_property_readonly("owner", [](PyOpResult &self) {
1953       assert(
1954           mlirOperationEqual(self.getParentOperation()->get(),
1955                              mlirOpResultGetOwner(self.get())) &&
1956           "expected the owner of the value in Python to match that in the IR");
1957       return self.getParentOperation().getObject();
1958     });
1959     c.def_property_readonly("result_number", [](PyOpResult &self) {
1960       return mlirOpResultGetResultNumber(self.get());
1961     });
1962   }
1963 };
1964 
1965 /// Returns the list of types of the values held by container.
1966 template <typename Container>
1967 static std::vector<PyType> getValueTypes(Container &container,
1968                                          PyMlirContextRef &context) {
1969   std::vector<PyType> result;
1970   result.reserve(container.size());
1971   for (int i = 0, e = container.size(); i < e; ++i) {
1972     result.push_back(
1973         PyType(context, mlirValueGetType(container.getElement(i).get())));
1974   }
1975   return result;
1976 }
1977 
1978 /// A list of block arguments. Internally, these are stored as consecutive
1979 /// elements, random access is cheap. The argument list is associated with the
1980 /// operation that contains the block (detached blocks are not allowed in
1981 /// Python bindings) and extends its lifetime.
1982 class PyBlockArgumentList
1983     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1984 public:
1985   static constexpr const char *pyClassName = "BlockArgumentList";
1986 
1987   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1988                       intptr_t startIndex = 0, intptr_t length = -1,
1989                       intptr_t step = 1)
1990       : Sliceable(startIndex,
1991                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1992                   step),
1993         operation(std::move(operation)), block(block) {}
1994 
1995   static void bindDerived(ClassTy &c) {
1996     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1997       return getValueTypes(self, self.operation->getContext());
1998     });
1999   }
2000 
2001 private:
2002   /// Give the parent CRTP class access to hook implementations below.
2003   friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2004 
2005   /// Returns the number of arguments in the list.
2006   intptr_t getRawNumElements() {
2007     operation->checkValid();
2008     return mlirBlockGetNumArguments(block);
2009   }
2010 
2011   /// Returns `pos`-the element in the list.
2012   PyBlockArgument getRawElement(intptr_t pos) {
2013     MlirValue argument = mlirBlockGetArgument(block, pos);
2014     return PyBlockArgument(operation, argument);
2015   }
2016 
2017   /// Returns a sublist of this list.
2018   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2019                             intptr_t step) {
2020     return PyBlockArgumentList(operation, block, startIndex, length, step);
2021   }
2022 
2023   PyOperationRef operation;
2024   MlirBlock block;
2025 };
2026 
2027 /// A list of operation operands. Internally, these are stored as consecutive
2028 /// elements, random access is cheap. The result list is associated with the
2029 /// operation whose results these are, and extends the lifetime of this
2030 /// operation.
2031 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2032 public:
2033   static constexpr const char *pyClassName = "OpOperandList";
2034 
2035   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2036                   intptr_t length = -1, intptr_t step = 1)
2037       : Sliceable(startIndex,
2038                   length == -1 ? mlirOperationGetNumOperands(operation->get())
2039                                : length,
2040                   step),
2041         operation(operation) {}
2042 
2043   void dunderSetItem(intptr_t index, PyValue value) {
2044     index = wrapIndex(index);
2045     mlirOperationSetOperand(operation->get(), index, value.get());
2046   }
2047 
2048   static void bindDerived(ClassTy &c) {
2049     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2050   }
2051 
2052 private:
2053   /// Give the parent CRTP class access to hook implementations below.
2054   friend class Sliceable<PyOpOperandList, PyValue>;
2055 
2056   intptr_t getRawNumElements() {
2057     operation->checkValid();
2058     return mlirOperationGetNumOperands(operation->get());
2059   }
2060 
2061   PyValue getRawElement(intptr_t pos) {
2062     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2063     MlirOperation owner;
2064     if (mlirValueIsAOpResult(operand))
2065       owner = mlirOpResultGetOwner(operand);
2066     else if (mlirValueIsABlockArgument(operand))
2067       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2068     else
2069       assert(false && "Value must be an block arg or op result.");
2070     PyOperationRef pyOwner =
2071         PyOperation::forOperation(operation->getContext(), owner);
2072     return PyValue(pyOwner, operand);
2073   }
2074 
2075   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2076     return PyOpOperandList(operation, startIndex, length, step);
2077   }
2078 
2079   PyOperationRef operation;
2080 };
2081 
2082 /// A list of operation results. Internally, these are stored as consecutive
2083 /// elements, random access is cheap. The result list is associated with the
2084 /// operation whose results these are, and extends the lifetime of this
2085 /// operation.
2086 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2087 public:
2088   static constexpr const char *pyClassName = "OpResultList";
2089 
2090   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2091                  intptr_t length = -1, intptr_t step = 1)
2092       : Sliceable(startIndex,
2093                   length == -1 ? mlirOperationGetNumResults(operation->get())
2094                                : length,
2095                   step),
2096         operation(operation) {}
2097 
2098   static void bindDerived(ClassTy &c) {
2099     c.def_property_readonly("types", [](PyOpResultList &self) {
2100       return getValueTypes(self, self.operation->getContext());
2101     });
2102   }
2103 
2104 private:
2105   /// Give the parent CRTP class access to hook implementations below.
2106   friend class Sliceable<PyOpResultList, PyOpResult>;
2107 
2108   intptr_t getRawNumElements() {
2109     operation->checkValid();
2110     return mlirOperationGetNumResults(operation->get());
2111   }
2112 
2113   PyOpResult getRawElement(intptr_t index) {
2114     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2115     return PyOpResult(value);
2116   }
2117 
2118   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2119     return PyOpResultList(operation, startIndex, length, step);
2120   }
2121 
2122   PyOperationRef operation;
2123 };
2124 
2125 /// A list of operation attributes. Can be indexed by name, producing
2126 /// attributes, or by index, producing named attributes.
2127 class PyOpAttributeMap {
2128 public:
2129   PyOpAttributeMap(PyOperationRef operation)
2130       : operation(std::move(operation)) {}
2131 
2132   PyAttribute dunderGetItemNamed(const std::string &name) {
2133     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2134                                                          toMlirStringRef(name));
2135     if (mlirAttributeIsNull(attr)) {
2136       throw SetPyError(PyExc_KeyError,
2137                        "attempt to access a non-existent attribute");
2138     }
2139     return PyAttribute(operation->getContext(), attr);
2140   }
2141 
2142   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2143     if (index < 0 || index >= dunderLen()) {
2144       throw SetPyError(PyExc_IndexError,
2145                        "attempt to access out of bounds attribute");
2146     }
2147     MlirNamedAttribute namedAttr =
2148         mlirOperationGetAttribute(operation->get(), index);
2149     return PyNamedAttribute(
2150         namedAttr.attribute,
2151         std::string(mlirIdentifierStr(namedAttr.name).data,
2152                     mlirIdentifierStr(namedAttr.name).length));
2153   }
2154 
2155   void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2156     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2157                                     attr);
2158   }
2159 
2160   void dunderDelItem(const std::string &name) {
2161     int removed = mlirOperationRemoveAttributeByName(operation->get(),
2162                                                      toMlirStringRef(name));
2163     if (!removed)
2164       throw SetPyError(PyExc_KeyError,
2165                        "attempt to delete a non-existent attribute");
2166   }
2167 
2168   intptr_t dunderLen() {
2169     return mlirOperationGetNumAttributes(operation->get());
2170   }
2171 
2172   bool dunderContains(const std::string &name) {
2173     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2174         operation->get(), toMlirStringRef(name)));
2175   }
2176 
2177   static void bind(py::module &m) {
2178     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2179         .def("__contains__", &PyOpAttributeMap::dunderContains)
2180         .def("__len__", &PyOpAttributeMap::dunderLen)
2181         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2182         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2183         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2184         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2185   }
2186 
2187 private:
2188   PyOperationRef operation;
2189 };
2190 
2191 } // namespace
2192 
2193 //------------------------------------------------------------------------------
2194 // Populates the core exports of the 'ir' submodule.
2195 //------------------------------------------------------------------------------
2196 
2197 void mlir::python::populateIRCore(py::module &m) {
2198   //----------------------------------------------------------------------------
2199   // Enums.
2200   //----------------------------------------------------------------------------
2201   py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2202       .value("ERROR", MlirDiagnosticError)
2203       .value("WARNING", MlirDiagnosticWarning)
2204       .value("NOTE", MlirDiagnosticNote)
2205       .value("REMARK", MlirDiagnosticRemark);
2206 
2207   //----------------------------------------------------------------------------
2208   // Mapping of Diagnostics.
2209   //----------------------------------------------------------------------------
2210   py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2211       .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2212       .def_property_readonly("location", &PyDiagnostic::getLocation)
2213       .def_property_readonly("message", &PyDiagnostic::getMessage)
2214       .def_property_readonly("notes", &PyDiagnostic::getNotes)
2215       .def("__str__", [](PyDiagnostic &self) -> py::str {
2216         if (!self.isValid())
2217           return "<Invalid Diagnostic>";
2218         return self.getMessage();
2219       });
2220 
2221   py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2222       .def("detach", &PyDiagnosticHandler::detach)
2223       .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2224       .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2225       .def("__enter__", &PyDiagnosticHandler::contextEnter)
2226       .def("__exit__", &PyDiagnosticHandler::contextExit);
2227 
2228   //----------------------------------------------------------------------------
2229   // Mapping of MlirContext.
2230   // Note that this is exported as _BaseContext. The containing, Python level
2231   // __init__.py will subclass it with site-specific functionality and set a
2232   // "Context" attribute on this module.
2233   //----------------------------------------------------------------------------
2234   py::class_<PyMlirContext>(m, "_BaseContext", py::module_local())
2235       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2236       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2237       .def("_get_context_again",
2238            [](PyMlirContext &self) {
2239              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2240              return ref.releaseObject();
2241            })
2242       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2243       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2244       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2245       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2246                              &PyMlirContext::getCapsule)
2247       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2248       .def("__enter__", &PyMlirContext::contextEnter)
2249       .def("__exit__", &PyMlirContext::contextExit)
2250       .def_property_readonly_static(
2251           "current",
2252           [](py::object & /*class*/) {
2253             auto *context = PyThreadContextEntry::getDefaultContext();
2254             if (!context)
2255               throw SetPyError(PyExc_ValueError, "No current Context");
2256             return context;
2257           },
2258           "Gets the Context bound to the current thread or raises ValueError")
2259       .def_property_readonly(
2260           "dialects",
2261           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2262           "Gets a container for accessing dialects by name")
2263       .def_property_readonly(
2264           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2265           "Alias for 'dialect'")
2266       .def(
2267           "get_dialect_descriptor",
2268           [=](PyMlirContext &self, std::string &name) {
2269             MlirDialect dialect = mlirContextGetOrLoadDialect(
2270                 self.get(), {name.data(), name.size()});
2271             if (mlirDialectIsNull(dialect)) {
2272               throw SetPyError(PyExc_ValueError,
2273                                Twine("Dialect '") + name + "' not found");
2274             }
2275             return PyDialectDescriptor(self.getRef(), dialect);
2276           },
2277           py::arg("dialect_name"),
2278           "Gets or loads a dialect by name, returning its descriptor object")
2279       .def_property(
2280           "allow_unregistered_dialects",
2281           [](PyMlirContext &self) -> bool {
2282             return mlirContextGetAllowUnregisteredDialects(self.get());
2283           },
2284           [](PyMlirContext &self, bool value) {
2285             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2286           })
2287       .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2288            py::arg("callback"),
2289            "Attaches a diagnostic handler that will receive callbacks")
2290       .def(
2291           "enable_multithreading",
2292           [](PyMlirContext &self, bool enable) {
2293             mlirContextEnableMultithreading(self.get(), enable);
2294           },
2295           py::arg("enable"))
2296       .def(
2297           "is_registered_operation",
2298           [](PyMlirContext &self, std::string &name) {
2299             return mlirContextIsRegisteredOperation(
2300                 self.get(), MlirStringRef{name.data(), name.size()});
2301           },
2302           py::arg("operation_name"))
2303       .def(
2304           "append_dialect_registry",
2305           [](PyMlirContext &self, PyDialectRegistry &registry) {
2306             mlirContextAppendDialectRegistry(self.get(), registry);
2307           },
2308           py::arg("registry"))
2309       .def("load_all_available_dialects", [](PyMlirContext &self) {
2310         mlirContextLoadAllAvailableDialects(self.get());
2311       });
2312 
2313   //----------------------------------------------------------------------------
2314   // Mapping of PyDialectDescriptor
2315   //----------------------------------------------------------------------------
2316   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2317       .def_property_readonly("namespace",
2318                              [](PyDialectDescriptor &self) {
2319                                MlirStringRef ns =
2320                                    mlirDialectGetNamespace(self.get());
2321                                return py::str(ns.data, ns.length);
2322                              })
2323       .def("__repr__", [](PyDialectDescriptor &self) {
2324         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2325         std::string repr("<DialectDescriptor ");
2326         repr.append(ns.data, ns.length);
2327         repr.append(">");
2328         return repr;
2329       });
2330 
2331   //----------------------------------------------------------------------------
2332   // Mapping of PyDialects
2333   //----------------------------------------------------------------------------
2334   py::class_<PyDialects>(m, "Dialects", py::module_local())
2335       .def("__getitem__",
2336            [=](PyDialects &self, std::string keyName) {
2337              MlirDialect dialect =
2338                  self.getDialectForKey(keyName, /*attrError=*/false);
2339              py::object descriptor =
2340                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2341              return createCustomDialectWrapper(keyName, std::move(descriptor));
2342            })
2343       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2344         MlirDialect dialect =
2345             self.getDialectForKey(attrName, /*attrError=*/true);
2346         py::object descriptor =
2347             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2348         return createCustomDialectWrapper(attrName, std::move(descriptor));
2349       });
2350 
2351   //----------------------------------------------------------------------------
2352   // Mapping of PyDialect
2353   //----------------------------------------------------------------------------
2354   py::class_<PyDialect>(m, "Dialect", py::module_local())
2355       .def(py::init<py::object>(), py::arg("descriptor"))
2356       .def_property_readonly(
2357           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2358       .def("__repr__", [](py::object self) {
2359         auto clazz = self.attr("__class__");
2360         return py::str("<Dialect ") +
2361                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2362                clazz.attr("__module__") + py::str(".") +
2363                clazz.attr("__name__") + py::str(")>");
2364       });
2365 
2366   //----------------------------------------------------------------------------
2367   // Mapping of PyDialectRegistry
2368   //----------------------------------------------------------------------------
2369   py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local())
2370       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2371                              &PyDialectRegistry::getCapsule)
2372       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2373       .def(py::init<>());
2374 
2375   //----------------------------------------------------------------------------
2376   // Mapping of Location
2377   //----------------------------------------------------------------------------
2378   py::class_<PyLocation>(m, "Location", py::module_local())
2379       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2380       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2381       .def("__enter__", &PyLocation::contextEnter)
2382       .def("__exit__", &PyLocation::contextExit)
2383       .def("__eq__",
2384            [](PyLocation &self, PyLocation &other) -> bool {
2385              return mlirLocationEqual(self, other);
2386            })
2387       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2388       .def_property_readonly_static(
2389           "current",
2390           [](py::object & /*class*/) {
2391             auto *loc = PyThreadContextEntry::getDefaultLocation();
2392             if (!loc)
2393               throw SetPyError(PyExc_ValueError, "No current Location");
2394             return loc;
2395           },
2396           "Gets the Location bound to the current thread or raises ValueError")
2397       .def_static(
2398           "unknown",
2399           [](DefaultingPyMlirContext context) {
2400             return PyLocation(context->getRef(),
2401                               mlirLocationUnknownGet(context->get()));
2402           },
2403           py::arg("context") = py::none(),
2404           "Gets a Location representing an unknown location")
2405       .def_static(
2406           "callsite",
2407           [](PyLocation callee, const std::vector<PyLocation> &frames,
2408              DefaultingPyMlirContext context) {
2409             if (frames.empty())
2410               throw py::value_error("No caller frames provided");
2411             MlirLocation caller = frames.back().get();
2412             for (const PyLocation &frame :
2413                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2414               caller = mlirLocationCallSiteGet(frame.get(), caller);
2415             return PyLocation(context->getRef(),
2416                               mlirLocationCallSiteGet(callee.get(), caller));
2417           },
2418           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2419           kContextGetCallSiteLocationDocstring)
2420       .def_static(
2421           "file",
2422           [](std::string filename, int line, int col,
2423              DefaultingPyMlirContext context) {
2424             return PyLocation(
2425                 context->getRef(),
2426                 mlirLocationFileLineColGet(
2427                     context->get(), toMlirStringRef(filename), line, col));
2428           },
2429           py::arg("filename"), py::arg("line"), py::arg("col"),
2430           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2431       .def_static(
2432           "fused",
2433           [](const std::vector<PyLocation> &pyLocations,
2434              llvm::Optional<PyAttribute> metadata,
2435              DefaultingPyMlirContext context) {
2436             llvm::SmallVector<MlirLocation, 4> locations;
2437             locations.reserve(pyLocations.size());
2438             for (auto &pyLocation : pyLocations)
2439               locations.push_back(pyLocation.get());
2440             MlirLocation location = mlirLocationFusedGet(
2441                 context->get(), locations.size(), locations.data(),
2442                 metadata ? metadata->get() : MlirAttribute{0});
2443             return PyLocation(context->getRef(), location);
2444           },
2445           py::arg("locations"), py::arg("metadata") = py::none(),
2446           py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2447       .def_static(
2448           "name",
2449           [](std::string name, llvm::Optional<PyLocation> childLoc,
2450              DefaultingPyMlirContext context) {
2451             return PyLocation(
2452                 context->getRef(),
2453                 mlirLocationNameGet(
2454                     context->get(), toMlirStringRef(name),
2455                     childLoc ? childLoc->get()
2456                              : mlirLocationUnknownGet(context->get())));
2457           },
2458           py::arg("name"), py::arg("childLoc") = py::none(),
2459           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2460       .def_property_readonly(
2461           "context",
2462           [](PyLocation &self) { return self.getContext().getObject(); },
2463           "Context that owns the Location")
2464       .def(
2465           "emit_error",
2466           [](PyLocation &self, std::string message) {
2467             mlirEmitError(self, message.c_str());
2468           },
2469           py::arg("message"), "Emits an error at this location")
2470       .def("__repr__", [](PyLocation &self) {
2471         PyPrintAccumulator printAccum;
2472         mlirLocationPrint(self, printAccum.getCallback(),
2473                           printAccum.getUserData());
2474         return printAccum.join();
2475       });
2476 
2477   //----------------------------------------------------------------------------
2478   // Mapping of Module
2479   //----------------------------------------------------------------------------
2480   py::class_<PyModule>(m, "Module", py::module_local())
2481       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2482       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2483       .def_static(
2484           "parse",
2485           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2486             MlirModule module = mlirModuleCreateParse(
2487                 context->get(), toMlirStringRef(moduleAsm));
2488             // TODO: Rework error reporting once diagnostic engine is exposed
2489             // in C API.
2490             if (mlirModuleIsNull(module)) {
2491               throw SetPyError(
2492                   PyExc_ValueError,
2493                   "Unable to parse module assembly (see diagnostics)");
2494             }
2495             return PyModule::forModule(module).releaseObject();
2496           },
2497           py::arg("asm"), py::arg("context") = py::none(),
2498           kModuleParseDocstring)
2499       .def_static(
2500           "create",
2501           [](DefaultingPyLocation loc) {
2502             MlirModule module = mlirModuleCreateEmpty(loc);
2503             return PyModule::forModule(module).releaseObject();
2504           },
2505           py::arg("loc") = py::none(), "Creates an empty module")
2506       .def_property_readonly(
2507           "context",
2508           [](PyModule &self) { return self.getContext().getObject(); },
2509           "Context that created the Module")
2510       .def_property_readonly(
2511           "operation",
2512           [](PyModule &self) {
2513             return PyOperation::forOperation(self.getContext(),
2514                                              mlirModuleGetOperation(self.get()),
2515                                              self.getRef().releaseObject())
2516                 .releaseObject();
2517           },
2518           "Accesses the module as an operation")
2519       .def_property_readonly(
2520           "body",
2521           [](PyModule &self) {
2522             PyOperationRef moduleOp = PyOperation::forOperation(
2523                 self.getContext(), mlirModuleGetOperation(self.get()),
2524                 self.getRef().releaseObject());
2525             PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2526             return returnBlock;
2527           },
2528           "Return the block for this module")
2529       .def(
2530           "dump",
2531           [](PyModule &self) {
2532             mlirOperationDump(mlirModuleGetOperation(self.get()));
2533           },
2534           kDumpDocstring)
2535       .def(
2536           "__str__",
2537           [](py::object self) {
2538             // Defer to the operation's __str__.
2539             return self.attr("operation").attr("__str__")();
2540           },
2541           kOperationStrDunderDocstring);
2542 
2543   //----------------------------------------------------------------------------
2544   // Mapping of Operation.
2545   //----------------------------------------------------------------------------
2546   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2547       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2548                              [](PyOperationBase &self) {
2549                                return self.getOperation().getCapsule();
2550                              })
2551       .def("__eq__",
2552            [](PyOperationBase &self, PyOperationBase &other) {
2553              return &self.getOperation() == &other.getOperation();
2554            })
2555       .def("__eq__",
2556            [](PyOperationBase &self, py::object other) { return false; })
2557       .def("__hash__",
2558            [](PyOperationBase &self) {
2559              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2560            })
2561       .def_property_readonly("attributes",
2562                              [](PyOperationBase &self) {
2563                                return PyOpAttributeMap(
2564                                    self.getOperation().getRef());
2565                              })
2566       .def_property_readonly("operands",
2567                              [](PyOperationBase &self) {
2568                                return PyOpOperandList(
2569                                    self.getOperation().getRef());
2570                              })
2571       .def_property_readonly("regions",
2572                              [](PyOperationBase &self) {
2573                                return PyRegionList(
2574                                    self.getOperation().getRef());
2575                              })
2576       .def_property_readonly(
2577           "results",
2578           [](PyOperationBase &self) {
2579             return PyOpResultList(self.getOperation().getRef());
2580           },
2581           "Returns the list of Operation results.")
2582       .def_property_readonly(
2583           "result",
2584           [](PyOperationBase &self) {
2585             auto &operation = self.getOperation();
2586             auto numResults = mlirOperationGetNumResults(operation);
2587             if (numResults != 1) {
2588               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2589               throw SetPyError(
2590                   PyExc_ValueError,
2591                   Twine("Cannot call .result on operation ") +
2592                       StringRef(name.data, name.length) + " which has " +
2593                       Twine(numResults) +
2594                       " results (it is only valid for operations with a "
2595                       "single result)");
2596             }
2597             return PyOpResult(operation.getRef(),
2598                               mlirOperationGetResult(operation, 0));
2599           },
2600           "Shortcut to get an op result if it has only one (throws an error "
2601           "otherwise).")
2602       .def_property_readonly(
2603           "location",
2604           [](PyOperationBase &self) {
2605             PyOperation &operation = self.getOperation();
2606             return PyLocation(operation.getContext(),
2607                               mlirOperationGetLocation(operation.get()));
2608           },
2609           "Returns the source location the operation was defined or derived "
2610           "from.")
2611       .def(
2612           "__str__",
2613           [](PyOperationBase &self) {
2614             return self.getAsm(/*binary=*/false,
2615                                /*largeElementsLimit=*/llvm::None,
2616                                /*enableDebugInfo=*/false,
2617                                /*prettyDebugInfo=*/false,
2618                                /*printGenericOpForm=*/false,
2619                                /*useLocalScope=*/false,
2620                                /*assumeVerified=*/false);
2621           },
2622           "Returns the assembly form of the operation.")
2623       .def("print", &PyOperationBase::print,
2624            // Careful: Lots of arguments must match up with print method.
2625            py::arg("file") = py::none(), py::arg("binary") = false,
2626            py::arg("large_elements_limit") = py::none(),
2627            py::arg("enable_debug_info") = false,
2628            py::arg("pretty_debug_info") = false,
2629            py::arg("print_generic_op_form") = false,
2630            py::arg("use_local_scope") = false,
2631            py::arg("assume_verified") = false, kOperationPrintDocstring)
2632       .def("get_asm", &PyOperationBase::getAsm,
2633            // Careful: Lots of arguments must match up with get_asm method.
2634            py::arg("binary") = false,
2635            py::arg("large_elements_limit") = py::none(),
2636            py::arg("enable_debug_info") = false,
2637            py::arg("pretty_debug_info") = false,
2638            py::arg("print_generic_op_form") = false,
2639            py::arg("use_local_scope") = false,
2640            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
2641       .def(
2642           "verify",
2643           [](PyOperationBase &self) {
2644             return mlirOperationVerify(self.getOperation());
2645           },
2646           "Verify the operation and return true if it passes, false if it "
2647           "fails.")
2648       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2649            "Puts self immediately after the other operation in its parent "
2650            "block.")
2651       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2652            "Puts self immediately before the other operation in its parent "
2653            "block.")
2654       .def(
2655           "detach_from_parent",
2656           [](PyOperationBase &self) {
2657             PyOperation &operation = self.getOperation();
2658             operation.checkValid();
2659             if (!operation.isAttached())
2660               throw py::value_error("Detached operation has no parent.");
2661 
2662             operation.detachFromParent();
2663             return operation.createOpView();
2664           },
2665           "Detaches the operation from its parent block.");
2666 
2667   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2668       .def_static("create", &PyOperation::create, py::arg("name"),
2669                   py::arg("results") = py::none(),
2670                   py::arg("operands") = py::none(),
2671                   py::arg("attributes") = py::none(),
2672                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2673                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2674                   kOperationCreateDocstring)
2675       .def_property_readonly("parent",
2676                              [](PyOperation &self) -> py::object {
2677                                auto parent = self.getParentOperation();
2678                                if (parent)
2679                                  return parent->getObject();
2680                                return py::none();
2681                              })
2682       .def("erase", &PyOperation::erase)
2683       .def("clone", &PyOperation::clone, py::arg("ip") = py::none())
2684       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2685                              &PyOperation::getCapsule)
2686       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2687       .def_property_readonly("name",
2688                              [](PyOperation &self) {
2689                                self.checkValid();
2690                                MlirOperation operation = self.get();
2691                                MlirStringRef name = mlirIdentifierStr(
2692                                    mlirOperationGetName(operation));
2693                                return py::str(name.data, name.length);
2694                              })
2695       .def_property_readonly(
2696           "context",
2697           [](PyOperation &self) {
2698             self.checkValid();
2699             return self.getContext().getObject();
2700           },
2701           "Context that owns the Operation")
2702       .def_property_readonly("opview", &PyOperation::createOpView);
2703 
2704   auto opViewClass =
2705       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2706           .def(py::init<py::object>(), py::arg("operation"))
2707           .def_property_readonly("operation", &PyOpView::getOperationObject)
2708           .def_property_readonly(
2709               "context",
2710               [](PyOpView &self) {
2711                 return self.getOperation().getContext().getObject();
2712               },
2713               "Context that owns the Operation")
2714           .def("__str__", [](PyOpView &self) {
2715             return py::str(self.getOperationObject());
2716           });
2717   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2718   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2719   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2720   opViewClass.attr("build_generic") = classmethod(
2721       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2722       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2723       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2724       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2725       "Builds a specific, generated OpView based on class level attributes.");
2726 
2727   //----------------------------------------------------------------------------
2728   // Mapping of PyRegion.
2729   //----------------------------------------------------------------------------
2730   py::class_<PyRegion>(m, "Region", py::module_local())
2731       .def_property_readonly(
2732           "blocks",
2733           [](PyRegion &self) {
2734             return PyBlockList(self.getParentOperation(), self.get());
2735           },
2736           "Returns a forward-optimized sequence of blocks.")
2737       .def_property_readonly(
2738           "owner",
2739           [](PyRegion &self) {
2740             return self.getParentOperation()->createOpView();
2741           },
2742           "Returns the operation owning this region.")
2743       .def(
2744           "__iter__",
2745           [](PyRegion &self) {
2746             self.checkValid();
2747             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2748             return PyBlockIterator(self.getParentOperation(), firstBlock);
2749           },
2750           "Iterates over blocks in the region.")
2751       .def("__eq__",
2752            [](PyRegion &self, PyRegion &other) {
2753              return self.get().ptr == other.get().ptr;
2754            })
2755       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2756 
2757   //----------------------------------------------------------------------------
2758   // Mapping of PyBlock.
2759   //----------------------------------------------------------------------------
2760   py::class_<PyBlock>(m, "Block", py::module_local())
2761       .def_property_readonly(
2762           "owner",
2763           [](PyBlock &self) {
2764             return self.getParentOperation()->createOpView();
2765           },
2766           "Returns the owning operation of this block.")
2767       .def_property_readonly(
2768           "region",
2769           [](PyBlock &self) {
2770             MlirRegion region = mlirBlockGetParentRegion(self.get());
2771             return PyRegion(self.getParentOperation(), region);
2772           },
2773           "Returns the owning region of this block.")
2774       .def_property_readonly(
2775           "arguments",
2776           [](PyBlock &self) {
2777             return PyBlockArgumentList(self.getParentOperation(), self.get());
2778           },
2779           "Returns a list of block arguments.")
2780       .def_property_readonly(
2781           "operations",
2782           [](PyBlock &self) {
2783             return PyOperationList(self.getParentOperation(), self.get());
2784           },
2785           "Returns a forward-optimized sequence of operations.")
2786       .def_static(
2787           "create_at_start",
2788           [](PyRegion &parent, py::list pyArgTypes) {
2789             parent.checkValid();
2790             llvm::SmallVector<MlirType, 4> argTypes;
2791             llvm::SmallVector<MlirLocation, 4> argLocs;
2792             argTypes.reserve(pyArgTypes.size());
2793             argLocs.reserve(pyArgTypes.size());
2794             for (auto &pyArg : pyArgTypes) {
2795               argTypes.push_back(pyArg.cast<PyType &>());
2796               // TODO: Pass in a proper location here.
2797               argLocs.push_back(
2798                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2799             }
2800 
2801             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2802                                               argLocs.data());
2803             mlirRegionInsertOwnedBlock(parent, 0, block);
2804             return PyBlock(parent.getParentOperation(), block);
2805           },
2806           py::arg("parent"), py::arg("arg_types") = py::list(),
2807           "Creates and returns a new Block at the beginning of the given "
2808           "region (with given argument types).")
2809       .def(
2810           "append_to",
2811           [](PyBlock &self, PyRegion &region) {
2812             MlirBlock b = self.get();
2813             if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
2814               mlirBlockDetach(b);
2815             mlirRegionAppendOwnedBlock(region.get(), b);
2816           },
2817           "Append this block to a region, transferring ownership if necessary")
2818       .def(
2819           "create_before",
2820           [](PyBlock &self, py::args pyArgTypes) {
2821             self.checkValid();
2822             llvm::SmallVector<MlirType, 4> argTypes;
2823             llvm::SmallVector<MlirLocation, 4> argLocs;
2824             argTypes.reserve(pyArgTypes.size());
2825             argLocs.reserve(pyArgTypes.size());
2826             for (auto &pyArg : pyArgTypes) {
2827               argTypes.push_back(pyArg.cast<PyType &>());
2828               // TODO: Pass in a proper location here.
2829               argLocs.push_back(
2830                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2831             }
2832 
2833             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2834                                               argLocs.data());
2835             MlirRegion region = mlirBlockGetParentRegion(self.get());
2836             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2837             return PyBlock(self.getParentOperation(), block);
2838           },
2839           "Creates and returns a new Block before this block "
2840           "(with given argument types).")
2841       .def(
2842           "create_after",
2843           [](PyBlock &self, py::args pyArgTypes) {
2844             self.checkValid();
2845             llvm::SmallVector<MlirType, 4> argTypes;
2846             llvm::SmallVector<MlirLocation, 4> argLocs;
2847             argTypes.reserve(pyArgTypes.size());
2848             argLocs.reserve(pyArgTypes.size());
2849             for (auto &pyArg : pyArgTypes) {
2850               argTypes.push_back(pyArg.cast<PyType &>());
2851 
2852               // TODO: Pass in a proper location here.
2853               argLocs.push_back(
2854                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2855             }
2856             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2857                                               argLocs.data());
2858             MlirRegion region = mlirBlockGetParentRegion(self.get());
2859             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2860             return PyBlock(self.getParentOperation(), block);
2861           },
2862           "Creates and returns a new Block after this block "
2863           "(with given argument types).")
2864       .def(
2865           "__iter__",
2866           [](PyBlock &self) {
2867             self.checkValid();
2868             MlirOperation firstOperation =
2869                 mlirBlockGetFirstOperation(self.get());
2870             return PyOperationIterator(self.getParentOperation(),
2871                                        firstOperation);
2872           },
2873           "Iterates over operations in the block.")
2874       .def("__eq__",
2875            [](PyBlock &self, PyBlock &other) {
2876              return self.get().ptr == other.get().ptr;
2877            })
2878       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2879       .def(
2880           "__str__",
2881           [](PyBlock &self) {
2882             self.checkValid();
2883             PyPrintAccumulator printAccum;
2884             mlirBlockPrint(self.get(), printAccum.getCallback(),
2885                            printAccum.getUserData());
2886             return printAccum.join();
2887           },
2888           "Returns the assembly form of the block.")
2889       .def(
2890           "append",
2891           [](PyBlock &self, PyOperationBase &operation) {
2892             if (operation.getOperation().isAttached())
2893               operation.getOperation().detachFromParent();
2894 
2895             MlirOperation mlirOperation = operation.getOperation().get();
2896             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2897             operation.getOperation().setAttached(
2898                 self.getParentOperation().getObject());
2899           },
2900           py::arg("operation"),
2901           "Appends an operation to this block. If the operation is currently "
2902           "in another block, it will be moved.");
2903 
2904   //----------------------------------------------------------------------------
2905   // Mapping of PyInsertionPoint.
2906   //----------------------------------------------------------------------------
2907 
2908   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2909       .def(py::init<PyBlock &>(), py::arg("block"),
2910            "Inserts after the last operation but still inside the block.")
2911       .def("__enter__", &PyInsertionPoint::contextEnter)
2912       .def("__exit__", &PyInsertionPoint::contextExit)
2913       .def_property_readonly_static(
2914           "current",
2915           [](py::object & /*class*/) {
2916             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2917             if (!ip)
2918               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2919             return ip;
2920           },
2921           "Gets the InsertionPoint bound to the current thread or raises "
2922           "ValueError if none has been set")
2923       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2924            "Inserts before a referenced operation.")
2925       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2926                   py::arg("block"), "Inserts at the beginning of the block.")
2927       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2928                   py::arg("block"), "Inserts before the block terminator.")
2929       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2930            "Inserts an operation.")
2931       .def_property_readonly(
2932           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2933           "Returns the block that this InsertionPoint points to.");
2934 
2935   //----------------------------------------------------------------------------
2936   // Mapping of PyAttribute.
2937   //----------------------------------------------------------------------------
2938   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2939       // Delegate to the PyAttribute copy constructor, which will also lifetime
2940       // extend the backing context which owns the MlirAttribute.
2941       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2942            "Casts the passed attribute to the generic Attribute")
2943       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2944                              &PyAttribute::getCapsule)
2945       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2946       .def_static(
2947           "parse",
2948           [](std::string attrSpec, DefaultingPyMlirContext context) {
2949             MlirAttribute type = mlirAttributeParseGet(
2950                 context->get(), toMlirStringRef(attrSpec));
2951             // TODO: Rework error reporting once diagnostic engine is exposed
2952             // in C API.
2953             if (mlirAttributeIsNull(type)) {
2954               throw SetPyError(PyExc_ValueError,
2955                                Twine("Unable to parse attribute: '") +
2956                                    attrSpec + "'");
2957             }
2958             return PyAttribute(context->getRef(), type);
2959           },
2960           py::arg("asm"), py::arg("context") = py::none(),
2961           "Parses an attribute from an assembly form")
2962       .def_property_readonly(
2963           "context",
2964           [](PyAttribute &self) { return self.getContext().getObject(); },
2965           "Context that owns the Attribute")
2966       .def_property_readonly("type",
2967                              [](PyAttribute &self) {
2968                                return PyType(self.getContext()->getRef(),
2969                                              mlirAttributeGetType(self));
2970                              })
2971       .def(
2972           "get_named",
2973           [](PyAttribute &self, std::string name) {
2974             return PyNamedAttribute(self, std::move(name));
2975           },
2976           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2977       .def("__eq__",
2978            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2979       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2980       .def("__hash__",
2981            [](PyAttribute &self) {
2982              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2983            })
2984       .def(
2985           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2986           kDumpDocstring)
2987       .def(
2988           "__str__",
2989           [](PyAttribute &self) {
2990             PyPrintAccumulator printAccum;
2991             mlirAttributePrint(self, printAccum.getCallback(),
2992                                printAccum.getUserData());
2993             return printAccum.join();
2994           },
2995           "Returns the assembly form of the Attribute.")
2996       .def("__repr__", [](PyAttribute &self) {
2997         // Generally, assembly formats are not printed for __repr__ because
2998         // this can cause exceptionally long debug output and exceptions.
2999         // However, attribute values are generally considered useful and are
3000         // printed. This may need to be re-evaluated if debug dumps end up
3001         // being excessive.
3002         PyPrintAccumulator printAccum;
3003         printAccum.parts.append("Attribute(");
3004         mlirAttributePrint(self, printAccum.getCallback(),
3005                            printAccum.getUserData());
3006         printAccum.parts.append(")");
3007         return printAccum.join();
3008       });
3009 
3010   //----------------------------------------------------------------------------
3011   // Mapping of PyNamedAttribute
3012   //----------------------------------------------------------------------------
3013   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
3014       .def("__repr__",
3015            [](PyNamedAttribute &self) {
3016              PyPrintAccumulator printAccum;
3017              printAccum.parts.append("NamedAttribute(");
3018              printAccum.parts.append(
3019                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
3020                          mlirIdentifierStr(self.namedAttr.name).length));
3021              printAccum.parts.append("=");
3022              mlirAttributePrint(self.namedAttr.attribute,
3023                                 printAccum.getCallback(),
3024                                 printAccum.getUserData());
3025              printAccum.parts.append(")");
3026              return printAccum.join();
3027            })
3028       .def_property_readonly(
3029           "name",
3030           [](PyNamedAttribute &self) {
3031             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
3032                            mlirIdentifierStr(self.namedAttr.name).length);
3033           },
3034           "The name of the NamedAttribute binding")
3035       .def_property_readonly(
3036           "attr",
3037           [](PyNamedAttribute &self) {
3038             // TODO: When named attribute is removed/refactored, also remove
3039             // this constructor (it does an inefficient table lookup).
3040             auto contextRef = PyMlirContext::forContext(
3041                 mlirAttributeGetContext(self.namedAttr.attribute));
3042             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
3043           },
3044           py::keep_alive<0, 1>(),
3045           "The underlying generic attribute of the NamedAttribute binding");
3046 
3047   //----------------------------------------------------------------------------
3048   // Mapping of PyType.
3049   //----------------------------------------------------------------------------
3050   py::class_<PyType>(m, "Type", py::module_local())
3051       // Delegate to the PyType copy constructor, which will also lifetime
3052       // extend the backing context which owns the MlirType.
3053       .def(py::init<PyType &>(), py::arg("cast_from_type"),
3054            "Casts the passed type to the generic Type")
3055       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3056       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3057       .def_static(
3058           "parse",
3059           [](std::string typeSpec, DefaultingPyMlirContext context) {
3060             MlirType type =
3061                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3062             // TODO: Rework error reporting once diagnostic engine is exposed
3063             // in C API.
3064             if (mlirTypeIsNull(type)) {
3065               throw SetPyError(PyExc_ValueError,
3066                                Twine("Unable to parse type: '") + typeSpec +
3067                                    "'");
3068             }
3069             return PyType(context->getRef(), type);
3070           },
3071           py::arg("asm"), py::arg("context") = py::none(),
3072           kContextParseTypeDocstring)
3073       .def_property_readonly(
3074           "context", [](PyType &self) { return self.getContext().getObject(); },
3075           "Context that owns the Type")
3076       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3077       .def("__eq__", [](PyType &self, py::object &other) { return false; })
3078       .def("__hash__",
3079            [](PyType &self) {
3080              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3081            })
3082       .def(
3083           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3084       .def(
3085           "__str__",
3086           [](PyType &self) {
3087             PyPrintAccumulator printAccum;
3088             mlirTypePrint(self, printAccum.getCallback(),
3089                           printAccum.getUserData());
3090             return printAccum.join();
3091           },
3092           "Returns the assembly form of the type.")
3093       .def("__repr__", [](PyType &self) {
3094         // Generally, assembly formats are not printed for __repr__ because
3095         // this can cause exceptionally long debug output and exceptions.
3096         // However, types are an exception as they typically have compact
3097         // assembly forms and printing them is useful.
3098         PyPrintAccumulator printAccum;
3099         printAccum.parts.append("Type(");
3100         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
3101         printAccum.parts.append(")");
3102         return printAccum.join();
3103       });
3104 
3105   //----------------------------------------------------------------------------
3106   // Mapping of Value.
3107   //----------------------------------------------------------------------------
3108   py::class_<PyValue>(m, "Value", py::module_local())
3109       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3110       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3111       .def_property_readonly(
3112           "context",
3113           [](PyValue &self) { return self.getParentOperation()->getContext(); },
3114           "Context in which the value lives.")
3115       .def(
3116           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3117           kDumpDocstring)
3118       .def_property_readonly(
3119           "owner",
3120           [](PyValue &self) {
3121             assert(mlirOperationEqual(self.getParentOperation()->get(),
3122                                       mlirOpResultGetOwner(self.get())) &&
3123                    "expected the owner of the value in Python to match that in "
3124                    "the IR");
3125             return self.getParentOperation().getObject();
3126           })
3127       .def("__eq__",
3128            [](PyValue &self, PyValue &other) {
3129              return self.get().ptr == other.get().ptr;
3130            })
3131       .def("__eq__", [](PyValue &self, py::object other) { return false; })
3132       .def("__hash__",
3133            [](PyValue &self) {
3134              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3135            })
3136       .def(
3137           "__str__",
3138           [](PyValue &self) {
3139             PyPrintAccumulator printAccum;
3140             printAccum.parts.append("Value(");
3141             mlirValuePrint(self.get(), printAccum.getCallback(),
3142                            printAccum.getUserData());
3143             printAccum.parts.append(")");
3144             return printAccum.join();
3145           },
3146           kValueDunderStrDocstring)
3147       .def_property_readonly("type", [](PyValue &self) {
3148         return PyType(self.getParentOperation()->getContext(),
3149                       mlirValueGetType(self.get()));
3150       });
3151   PyBlockArgument::bind(m);
3152   PyOpResult::bind(m);
3153 
3154   //----------------------------------------------------------------------------
3155   // Mapping of SymbolTable.
3156   //----------------------------------------------------------------------------
3157   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3158       .def(py::init<PyOperationBase &>())
3159       .def("__getitem__", &PySymbolTable::dunderGetItem)
3160       .def("insert", &PySymbolTable::insert, py::arg("operation"))
3161       .def("erase", &PySymbolTable::erase, py::arg("operation"))
3162       .def("__delitem__", &PySymbolTable::dunderDel)
3163       .def("__contains__",
3164            [](PySymbolTable &table, const std::string &name) {
3165              return !mlirOperationIsNull(mlirSymbolTableLookup(
3166                  table, mlirStringRefCreate(name.data(), name.length())));
3167            })
3168       // Static helpers.
3169       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3170                   py::arg("symbol"), py::arg("name"))
3171       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3172                   py::arg("symbol"))
3173       .def_static("get_visibility", &PySymbolTable::getVisibility,
3174                   py::arg("symbol"))
3175       .def_static("set_visibility", &PySymbolTable::setVisibility,
3176                   py::arg("symbol"), py::arg("visibility"))
3177       .def_static("replace_all_symbol_uses",
3178                   &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3179                   py::arg("new_symbol"), py::arg("from_op"))
3180       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3181                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
3182                   py::arg("callback"));
3183 
  // Container bindings: register the list / iterator / map helper classes
  // (block arguments, blocks, operations, op attributes, operands, results,
  // regions) on module `m`. The call order is kept as-is.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings: register PyGlobalDebugFlag on module `m`.
  PyGlobalDebugFlag::bind(m);
3198 }
3199