1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 
23 #include <utility>
24 
25 namespace py = pybind11;
26 using namespace mlir;
27 using namespace mlir::python;
28 
29 using llvm::SmallVector;
30 using llvm::StringRef;
31 using llvm::Twine;
32 
33 //------------------------------------------------------------------------------
34 // Docstrings (trivial, non-duplicated docstrings are included inline).
35 //------------------------------------------------------------------------------
36 
// The constants below are Python docstrings handed to pybind11 when the
// bindings are defined (kAppendBlockDocstring, for example, is passed to
// `.def("append", ...)` further down). They are runtime strings, not
// comments, so any edit to their text is a user-visible change.

/// Docstring for parsing a type from its textual assembly form.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

/// Docstring for constructing a call-site location.
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

/// Docstring for constructing a file/line/column location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

/// Docstring for constructing a fused location.
static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

/// Docstring for constructing a named location. (Note the "DocString"
/// capitalization differs from its siblings; references elsewhere depend on
/// this exact spelling.)
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

/// Docstring for parsing a module from its textual assembly form.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

/// Docstring for the generic operation-creation entry point; documents every
/// keyword argument it accepts.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
82 
/// User-visible docstring for Operation.print(). Each entry under "Args:" must
/// spell the keyword argument exactly as the binding declares it, because
/// Python users copy these names into calls.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Elide elements attributes with more than this number
    of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";
108 
/// Docstring for Operation.get_asm(); it defers to the print() docstring for
/// the shared keyword arguments.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

/// Docstring for an operation's `__str__`.
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

/// Generic docstring shared by `dump()` methods.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

/// Docstring for BlockList.append (bound in PyBlockList::bind below).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

/// Docstring for a value's `__str__`.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
147 
148 //------------------------------------------------------------------------------
149 // Utilities.
150 //------------------------------------------------------------------------------
151 
152 /// Helper for creating an @classmethod.
153 template <class Func, typename... Args>
154 py::object classmethod(Func f, Args... args) {
155   py::object cf = py::cpp_function(f, args...);
156   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
157 }
158 
159 static py::object
160 createCustomDialectWrapper(const std::string &dialectNamespace,
161                            py::object dialectDescriptor) {
162   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
163   if (!dialectClass) {
164     // Use the base class.
165     return py::cast(PyDialect(std::move(dialectDescriptor)));
166   }
167 
168   // Create the custom implementation.
169   return (*dialectClass)(std::move(dialectDescriptor));
170 }
171 
172 static MlirStringRef toMlirStringRef(const std::string &s) {
173   return mlirStringRefCreate(s.data(), s.size());
174 }
175 
176 /// Wrapper for the global LLVM debugging flag.
177 struct PyGlobalDebugFlag {
178   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
179 
180   static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
181 
182   static void bind(py::module &m) {
183     // Debug flags.
184     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
185         .def_property_static("flag", &PyGlobalDebugFlag::get,
186                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
187   }
188 };
189 
190 //------------------------------------------------------------------------------
191 // Collections.
192 //------------------------------------------------------------------------------
193 
194 namespace {
195 
196 class PyRegionIterator {
197 public:
198   PyRegionIterator(PyOperationRef operation)
199       : operation(std::move(operation)) {}
200 
201   PyRegionIterator &dunderIter() { return *this; }
202 
203   PyRegion dunderNext() {
204     operation->checkValid();
205     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
206       throw py::stop_iteration();
207     }
208     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
209     return PyRegion(operation, region);
210   }
211 
212   static void bind(py::module &m) {
213     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
214         .def("__iter__", &PyRegionIterator::dunderIter)
215         .def("__next__", &PyRegionIterator::dunderNext);
216   }
217 
218 private:
219   PyOperationRef operation;
220   int nextIndex = 0;
221 };
222 
223 /// Regions of an op are fixed length and indexed numerically so are represented
224 /// with a sequence-like container.
225 class PyRegionList {
226 public:
227   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
228 
229   intptr_t dunderLen() {
230     operation->checkValid();
231     return mlirOperationGetNumRegions(operation->get());
232   }
233 
234   PyRegion dunderGetItem(intptr_t index) {
235     // dunderLen checks validity.
236     if (index < 0 || index >= dunderLen()) {
237       throw SetPyError(PyExc_IndexError,
238                        "attempt to access out of bounds region");
239     }
240     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
241     return PyRegion(operation, region);
242   }
243 
244   static void bind(py::module &m) {
245     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
246         .def("__len__", &PyRegionList::dunderLen)
247         .def("__getitem__", &PyRegionList::dunderGetItem);
248   }
249 
250 private:
251   PyOperationRef operation;
252 };
253 
254 class PyBlockIterator {
255 public:
256   PyBlockIterator(PyOperationRef operation, MlirBlock next)
257       : operation(std::move(operation)), next(next) {}
258 
259   PyBlockIterator &dunderIter() { return *this; }
260 
261   PyBlock dunderNext() {
262     operation->checkValid();
263     if (mlirBlockIsNull(next)) {
264       throw py::stop_iteration();
265     }
266 
267     PyBlock returnBlock(operation, next);
268     next = mlirBlockGetNextInRegion(next);
269     return returnBlock;
270   }
271 
272   static void bind(py::module &m) {
273     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
274         .def("__iter__", &PyBlockIterator::dunderIter)
275         .def("__next__", &PyBlockIterator::dunderNext);
276   }
277 
278 private:
279   PyOperationRef operation;
280   MlirBlock next;
281 };
282 
283 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
284 /// we present them as a more full-featured list-like container but optimize
285 /// it for forward iteration. Blocks are always owned by a region.
286 class PyBlockList {
287 public:
288   PyBlockList(PyOperationRef operation, MlirRegion region)
289       : operation(std::move(operation)), region(region) {}
290 
291   PyBlockIterator dunderIter() {
292     operation->checkValid();
293     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
294   }
295 
296   intptr_t dunderLen() {
297     operation->checkValid();
298     intptr_t count = 0;
299     MlirBlock block = mlirRegionGetFirstBlock(region);
300     while (!mlirBlockIsNull(block)) {
301       count += 1;
302       block = mlirBlockGetNextInRegion(block);
303     }
304     return count;
305   }
306 
307   PyBlock dunderGetItem(intptr_t index) {
308     operation->checkValid();
309     if (index < 0) {
310       throw SetPyError(PyExc_IndexError,
311                        "attempt to access out of bounds block");
312     }
313     MlirBlock block = mlirRegionGetFirstBlock(region);
314     while (!mlirBlockIsNull(block)) {
315       if (index == 0) {
316         return PyBlock(operation, block);
317       }
318       block = mlirBlockGetNextInRegion(block);
319       index -= 1;
320     }
321     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
322   }
323 
324   PyBlock appendBlock(const py::args &pyArgTypes) {
325     operation->checkValid();
326     llvm::SmallVector<MlirType, 4> argTypes;
327     llvm::SmallVector<MlirLocation, 4> argLocs;
328     argTypes.reserve(pyArgTypes.size());
329     argLocs.reserve(pyArgTypes.size());
330     for (auto &pyArg : pyArgTypes) {
331       argTypes.push_back(pyArg.cast<PyType &>());
332       // TODO: Pass in a proper location here.
333       argLocs.push_back(
334           mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
335     }
336 
337     MlirBlock block =
338         mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
339     mlirRegionAppendOwnedBlock(region, block);
340     return PyBlock(operation, block);
341   }
342 
343   static void bind(py::module &m) {
344     py::class_<PyBlockList>(m, "BlockList", py::module_local())
345         .def("__getitem__", &PyBlockList::dunderGetItem)
346         .def("__iter__", &PyBlockList::dunderIter)
347         .def("__len__", &PyBlockList::dunderLen)
348         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
349   }
350 
351 private:
352   PyOperationRef operation;
353   MlirRegion region;
354 };
355 
356 class PyOperationIterator {
357 public:
358   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
359       : parentOperation(std::move(parentOperation)), next(next) {}
360 
361   PyOperationIterator &dunderIter() { return *this; }
362 
363   py::object dunderNext() {
364     parentOperation->checkValid();
365     if (mlirOperationIsNull(next)) {
366       throw py::stop_iteration();
367     }
368 
369     PyOperationRef returnOperation =
370         PyOperation::forOperation(parentOperation->getContext(), next);
371     next = mlirOperationGetNextInBlock(next);
372     return returnOperation->createOpView();
373   }
374 
375   static void bind(py::module &m) {
376     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
377         .def("__iter__", &PyOperationIterator::dunderIter)
378         .def("__next__", &PyOperationIterator::dunderNext);
379   }
380 
381 private:
382   PyOperationRef parentOperation;
383   MlirOperation next;
384 };
385 
386 /// Operations are exposed by the C-API as a forward-only linked list. In
387 /// Python, we present them as a more full-featured list-like container but
388 /// optimize it for forward iteration. Iterable operations are always owned
389 /// by a block.
390 class PyOperationList {
391 public:
392   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
393       : parentOperation(std::move(parentOperation)), block(block) {}
394 
395   PyOperationIterator dunderIter() {
396     parentOperation->checkValid();
397     return PyOperationIterator(parentOperation,
398                                mlirBlockGetFirstOperation(block));
399   }
400 
401   intptr_t dunderLen() {
402     parentOperation->checkValid();
403     intptr_t count = 0;
404     MlirOperation childOp = mlirBlockGetFirstOperation(block);
405     while (!mlirOperationIsNull(childOp)) {
406       count += 1;
407       childOp = mlirOperationGetNextInBlock(childOp);
408     }
409     return count;
410   }
411 
412   py::object dunderGetItem(intptr_t index) {
413     parentOperation->checkValid();
414     if (index < 0) {
415       throw SetPyError(PyExc_IndexError,
416                        "attempt to access out of bounds operation");
417     }
418     MlirOperation childOp = mlirBlockGetFirstOperation(block);
419     while (!mlirOperationIsNull(childOp)) {
420       if (index == 0) {
421         return PyOperation::forOperation(parentOperation->getContext(), childOp)
422             ->createOpView();
423       }
424       childOp = mlirOperationGetNextInBlock(childOp);
425       index -= 1;
426     }
427     throw SetPyError(PyExc_IndexError,
428                      "attempt to access out of bounds operation");
429   }
430 
431   static void bind(py::module &m) {
432     py::class_<PyOperationList>(m, "OperationList", py::module_local())
433         .def("__getitem__", &PyOperationList::dunderGetItem)
434         .def("__iter__", &PyOperationList::dunderIter)
435         .def("__len__", &PyOperationList::dunderLen);
436   }
437 
438 private:
439   PyOperationRef parentOperation;
440   MlirBlock block;
441 };
442 
443 } // namespace
444 
445 //------------------------------------------------------------------------------
446 // PyMlirContext
447 //------------------------------------------------------------------------------
448 
449 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
450   py::gil_scoped_acquire acquire;
451   auto &liveContexts = getLiveContexts();
452   liveContexts[context.ptr] = this;
453 }
454 
455 PyMlirContext::~PyMlirContext() {
456   // Note that the only public way to construct an instance is via the
457   // forContext method, which always puts the associated handle into
458   // liveContexts.
459   py::gil_scoped_acquire acquire;
460   getLiveContexts().erase(context.ptr);
461   mlirContextDestroy(context);
462 }
463 
464 py::object PyMlirContext::getCapsule() {
465   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
466 }
467 
468 py::object PyMlirContext::createFromCapsule(py::object capsule) {
469   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
470   if (mlirContextIsNull(rawContext))
471     throw py::error_already_set();
472   return forContext(rawContext).releaseObject();
473 }
474 
475 PyMlirContext *PyMlirContext::createNewContextForInit() {
476   MlirContext context = mlirContextCreate();
477   mlirRegisterAllDialects(context);
478   return new PyMlirContext(context);
479 }
480 
/// Returns a PyMlirContextRef for `context`, reusing the live wrapper for that
/// handle if one exists and otherwise creating a new PyMlirContext. The
/// process-wide live-context map is consulted while holding the GIL.
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
  py::gil_scoped_acquire acquire;
  auto &liveContexts = getLiveContexts();
  auto it = liveContexts.find(context.ptr);
  if (it == liveContexts.end()) {
    // Create.
    // NOTE(review): the PyMlirContext constructor already inserts the new
    // wrapper into liveContexts, so the explicit insertion below appears
    // redundant (it stores the same pointer again).
    // NOTE(review): py::cast is used with its default return-value policy on a
    // raw pointer; confirm which side is meant to own this wrapper.
    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
    py::object pyRef = py::cast(unownedContextWrapper);
    assert(pyRef && "cast to py::object failed");
    liveContexts[context.ptr] = unownedContextWrapper;
    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
  }
  // Use existing.
  py::object pyRef = py::cast(it->second);
  return PyMlirContextRef(it->second, std::move(pyRef));
}
497 
498 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
499   static LiveContextMap liveContexts;
500   return liveContexts;
501 }
502 
503 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
504 
505 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
506 
507 size_t PyMlirContext::clearLiveOperations() {
508   for (auto &op : liveOperations)
509     op.second.second->setInvalid();
510   size_t numInvalidated = liveOperations.size();
511   liveOperations.clear();
512   return numInvalidated;
513 }
514 
515 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
516 
517 pybind11::object PyMlirContext::contextEnter() {
518   return PyThreadContextEntry::pushContext(*this);
519 }
520 
521 void PyMlirContext::contextExit(const pybind11::object &excType,
522                                 const pybind11::object &excVal,
523                                 const pybind11::object &excTb) {
524   PyThreadContextEntry::popContext(*this);
525 }
526 
527 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
528   // Note that ownership is transferred to the delete callback below by way of
529   // an explicit inc_ref (borrow).
530   PyDiagnosticHandler *pyHandler =
531       new PyDiagnosticHandler(get(), std::move(callback));
532   py::object pyHandlerObject =
533       py::cast(pyHandler, py::return_value_policy::take_ownership);
534   pyHandlerObject.inc_ref();
535 
536   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
537   // guaranteed to be known to pybind.
538   auto handlerCallback =
539       +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
540     PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
541     py::object pyDiagnosticObject =
542         py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
543 
544     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
545     bool result = false;
546     {
547       // Since this can be called from arbitrary C++ contexts, always get the
548       // gil.
549       py::gil_scoped_acquire gil;
550       try {
551         result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
552       } catch (std::exception &e) {
553         fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
554                 e.what());
555         pyHandler->hadError = true;
556       }
557     }
558 
559     pyDiagnostic->invalidate();
560     return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
561   };
562   auto deleteCallback = +[](void *userData) {
563     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
564     assert(pyHandler->registeredID && "handler is not registered");
565     pyHandler->registeredID.reset();
566 
567     // Decrement reference, balancing the inc_ref() above.
568     py::object pyHandlerObject =
569         py::cast(pyHandler, py::return_value_policy::reference);
570     pyHandlerObject.dec_ref();
571   };
572 
573   pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
574       get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
575   return pyHandlerObject;
576 }
577 
578 PyMlirContext &DefaultingPyMlirContext::resolve() {
579   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
580   if (!context) {
581     throw SetPyError(
582         PyExc_RuntimeError,
583         "An MLIR function requires a Context but none was provided in the call "
584         "or from the surrounding environment. Either pass to the function with "
585         "a 'context=' argument or establish a default using 'with Context():'");
586   }
587   return *context;
588 }
589 
590 //------------------------------------------------------------------------------
591 // PyThreadContextEntry management
592 //------------------------------------------------------------------------------
593 
594 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
595   static thread_local std::vector<PyThreadContextEntry> stack;
596   return stack;
597 }
598 
599 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
600   auto &stack = getStack();
601   if (stack.empty())
602     return nullptr;
603   return &stack.back();
604 }
605 
606 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
607                                 py::object insertionPoint,
608                                 py::object location) {
609   auto &stack = getStack();
610   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
611                      std::move(location));
612   // If the new stack has more than one entry and the context of the new top
613   // entry matches the previous, copy the insertionPoint and location from the
614   // previous entry if missing from the new top entry.
615   if (stack.size() > 1) {
616     auto &prev = *(stack.rbegin() + 1);
617     auto &current = stack.back();
618     if (current.context.is(prev.context)) {
619       // Default non-context objects from the previous entry.
620       if (!current.insertionPoint)
621         current.insertionPoint = prev.insertionPoint;
622       if (!current.location)
623         current.location = prev.location;
624     }
625   }
626 }
627 
628 PyMlirContext *PyThreadContextEntry::getContext() {
629   if (!context)
630     return nullptr;
631   return py::cast<PyMlirContext *>(context);
632 }
633 
634 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
635   if (!insertionPoint)
636     return nullptr;
637   return py::cast<PyInsertionPoint *>(insertionPoint);
638 }
639 
640 PyLocation *PyThreadContextEntry::getLocation() {
641   if (!location)
642     return nullptr;
643   return py::cast<PyLocation *>(location);
644 }
645 
646 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
647   auto *tos = getTopOfStack();
648   return tos ? tos->getContext() : nullptr;
649 }
650 
651 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
652   auto *tos = getTopOfStack();
653   return tos ? tos->getInsertionPoint() : nullptr;
654 }
655 
656 PyLocation *PyThreadContextEntry::getDefaultLocation() {
657   auto *tos = getTopOfStack();
658   return tos ? tos->getLocation() : nullptr;
659 }
660 
661 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
662   py::object contextObj = py::cast(context);
663   push(FrameKind::Context, /*context=*/contextObj,
664        /*insertionPoint=*/py::object(),
665        /*location=*/py::object());
666   return contextObj;
667 }
668 
669 void PyThreadContextEntry::popContext(PyMlirContext &context) {
670   auto &stack = getStack();
671   if (stack.empty())
672     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
673   auto &tos = stack.back();
674   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
675     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
676   stack.pop_back();
677 }
678 
679 py::object
680 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
681   py::object contextObj =
682       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
683   py::object insertionPointObj = py::cast(insertionPoint);
684   push(FrameKind::InsertionPoint,
685        /*context=*/contextObj,
686        /*insertionPoint=*/insertionPointObj,
687        /*location=*/py::object());
688   return insertionPointObj;
689 }
690 
691 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
692   auto &stack = getStack();
693   if (stack.empty())
694     throw SetPyError(PyExc_RuntimeError,
695                      "Unbalanced InsertionPoint enter/exit");
696   auto &tos = stack.back();
697   if (tos.frameKind != FrameKind::InsertionPoint &&
698       tos.getInsertionPoint() != &insertionPoint)
699     throw SetPyError(PyExc_RuntimeError,
700                      "Unbalanced InsertionPoint enter/exit");
701   stack.pop_back();
702 }
703 
704 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
705   py::object contextObj = location.getContext().getObject();
706   py::object locationObj = py::cast(location);
707   push(FrameKind::Location, /*context=*/contextObj,
708        /*insertionPoint=*/py::object(),
709        /*location=*/locationObj);
710   return locationObj;
711 }
712 
713 void PyThreadContextEntry::popLocation(PyLocation &location) {
714   auto &stack = getStack();
715   if (stack.empty())
716     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
717   auto &tos = stack.back();
718   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
719     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
720   stack.pop_back();
721 }
722 
723 //------------------------------------------------------------------------------
724 // PyDiagnostic*
725 //------------------------------------------------------------------------------
726 
727 void PyDiagnostic::invalidate() {
728   valid = false;
729   if (materializedNotes) {
730     for (auto &noteObject : *materializedNotes) {
731       PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
732       note->invalidate();
733     }
734   }
735 }
736 
737 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
738                                          py::object callback)
739     : context(context), callback(std::move(callback)) {}
740 
741 PyDiagnosticHandler::~PyDiagnosticHandler() = default;
742 
/// Detaches this handler from its context if it is still registered; does
/// nothing otherwise.
void PyDiagnosticHandler::detach() {
  if (!registeredID)
    return;
  // Copy the ID out first: the assert below expects that detaching has run the
  // delete callback, which resets `registeredID`.
  // NOTE(review): that callback also drops a Python reference to this handler;
  // touching `this` afterwards is only safe while the caller still holds a
  // reference (true when invoked from Python) — confirm for other callers.
  MlirDiagnosticHandlerID localID = *registeredID;
  mlirContextDetachDiagnosticHandler(context, localID);
  assert(!registeredID && "should have unregistered");
  // Not strictly necessary but keeps stale pointers from being around to cause
  // issues.
  context = {nullptr};
}
753 
754 void PyDiagnostic::checkValid() {
755   if (!valid) {
756     throw std::invalid_argument(
757         "Diagnostic is invalid (used outside of callback)");
758   }
759 }
760 
761 MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
762   checkValid();
763   return mlirDiagnosticGetSeverity(diagnostic);
764 }
765 
766 PyLocation PyDiagnostic::getLocation() {
767   checkValid();
768   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
769   MlirContext context = mlirLocationGetContext(loc);
770   return PyLocation(PyMlirContext::forContext(context), loc);
771 }
772 
773 py::str PyDiagnostic::getMessage() {
774   checkValid();
775   py::object fileObject = py::module::import("io").attr("StringIO")();
776   PyFileAccumulator accum(fileObject, /*binary=*/false);
777   mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
778   return fileObject.attr("getvalue")();
779 }
780 
781 py::tuple PyDiagnostic::getNotes() {
782   checkValid();
783   if (materializedNotes)
784     return *materializedNotes;
785   intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
786   materializedNotes = py::tuple(numNotes);
787   for (intptr_t i = 0; i < numNotes; ++i) {
788     MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
789     py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag));
790     PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr());
791   }
792   return *materializedNotes;
793 }
794 
795 //------------------------------------------------------------------------------
796 // PyDialect, PyDialectDescriptor, PyDialects
797 //------------------------------------------------------------------------------
798 
799 MlirDialect PyDialects::getDialectForKey(const std::string &key,
800                                          bool attrError) {
801   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
802                                                     {key.data(), key.size()});
803   if (mlirDialectIsNull(dialect)) {
804     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
805                      Twine("Dialect '") + key + "' not found");
806   }
807   return dialect;
808 }
809 
810 //------------------------------------------------------------------------------
811 // PyLocation
812 //------------------------------------------------------------------------------
813 
814 py::object PyLocation::getCapsule() {
815   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
816 }
817 
818 PyLocation PyLocation::createFromCapsule(py::object capsule) {
819   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
820   if (mlirLocationIsNull(rawLoc))
821     throw py::error_already_set();
822   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
823                     rawLoc);
824 }
825 
826 py::object PyLocation::contextEnter() {
827   return PyThreadContextEntry::pushLocation(*this);
828 }
829 
830 void PyLocation::contextExit(const pybind11::object &excType,
831                              const pybind11::object &excVal,
832                              const pybind11::object &excTb) {
833   PyThreadContextEntry::popLocation(*this);
834 }
835 
836 PyLocation &DefaultingPyLocation::resolve() {
837   auto *location = PyThreadContextEntry::getDefaultLocation();
838   if (!location) {
839     throw SetPyError(
840         PyExc_RuntimeError,
841         "An MLIR function requires a Location but none was provided in the "
842         "call or from the surrounding environment. Either pass to the function "
843         "with a 'loc=' argument or establish a default using 'with loc:'");
844   }
845   return *location;
846 }
847 
848 //------------------------------------------------------------------------------
849 // PyModule
850 //------------------------------------------------------------------------------
851 
/// Wraps `module`, keeping a reference to its owning context. Instances are
/// created by PyModule::forModule, which also registers them in the context's
/// liveModules map.
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
    : BaseContextObject(std::move(contextRef)), module(module) {}
854 
/// Unregisters this wrapper from the context's liveModules map and destroys
/// the underlying MlirModule (the wrapper owns it).
PyModule::~PyModule() {
  // Hold the GIL while touching the live map and Python-managed state;
  // presumably destruction may be triggered without the GIL held.
  py::gil_scoped_acquire acquire;
  auto &liveModules = getContext()->liveModules;
  assert(liveModules.count(module.ptr) == 1 &&
         "destroying module not in live map");
  liveModules.erase(module.ptr);
  mlirModuleDestroy(module);
}
863 
864 PyModuleRef PyModule::forModule(MlirModule module) {
865   MlirContext context = mlirModuleGetContext(module);
866   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
867 
868   py::gil_scoped_acquire acquire;
869   auto &liveModules = contextRef->liveModules;
870   auto it = liveModules.find(module.ptr);
871   if (it == liveModules.end()) {
872     // Create.
873     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
874     // Note that the default return value policy on cast is automatic_reference,
875     // which does not take ownership (delete will not be called).
876     // Just be explicit.
877     py::object pyRef =
878         py::cast(unownedModule, py::return_value_policy::take_ownership);
879     unownedModule->handle = pyRef;
880     liveModules[module.ptr] =
881         std::make_pair(unownedModule->handle, unownedModule);
882     return PyModuleRef(unownedModule, std::move(pyRef));
883   }
884   // Use existing.
885   PyModule *existing = it->second.second;
886   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
887   return PyModuleRef(existing, std::move(pyRef));
888 }
889 
890 py::object PyModule::createFromCapsule(py::object capsule) {
891   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
892   if (mlirModuleIsNull(rawModule))
893     throw py::error_already_set();
894   return forModule(rawModule).releaseObject();
895 }
896 
897 py::object PyModule::getCapsule() {
898   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
899 }
900 
901 //------------------------------------------------------------------------------
902 // PyOperation
903 //------------------------------------------------------------------------------
904 
/// Wraps `operation`, keeping a reference to its owning context. Instances
/// are created through PyOperation::createInstance, which also registers them
/// in the context's liveOperations map.
PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
    : BaseContextObject(std::move(contextRef)), operation(operation) {}
907 
/// Unregisters this wrapper from the context's liveOperations map and, if
/// the operation is detached, destroys it.
PyOperation::~PyOperation() {
  // If the operation has already been invalidated there is nothing to do.
  // erase() has already removed it from the live map and destroyed it.
  if (!valid)
    return;
  auto &liveOperations = getContext()->liveOperations;
  assert(liveOperations.count(operation.ptr) == 1 &&
         "destroying operation not in live map");
  liveOperations.erase(operation.ptr);
  // Only detached operations are destroyed here; attached ones are presumably
  // owned by their enclosing block and released with it.
  if (!isAttached()) {
    mlirOperationDestroy(operation);
  }
}
920 
921 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
922                                            MlirOperation operation,
923                                            py::object parentKeepAlive) {
924   auto &liveOperations = contextRef->liveOperations;
925   // Create.
926   PyOperation *unownedOperation =
927       new PyOperation(std::move(contextRef), operation);
928   // Note that the default return value policy on cast is automatic_reference,
929   // which does not take ownership (delete will not be called).
930   // Just be explicit.
931   py::object pyRef =
932       py::cast(unownedOperation, py::return_value_policy::take_ownership);
933   unownedOperation->handle = pyRef;
934   if (parentKeepAlive) {
935     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
936   }
937   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
938   return PyOperationRef(unownedOperation, std::move(pyRef));
939 }
940 
941 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
942                                          MlirOperation operation,
943                                          py::object parentKeepAlive) {
944   auto &liveOperations = contextRef->liveOperations;
945   auto it = liveOperations.find(operation.ptr);
946   if (it == liveOperations.end()) {
947     // Create.
948     return createInstance(std::move(contextRef), operation,
949                           std::move(parentKeepAlive));
950   }
951   // Use existing.
952   PyOperation *existing = it->second.second;
953   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
954   return PyOperationRef(existing, std::move(pyRef));
955 }
956 
957 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
958                                            MlirOperation operation,
959                                            py::object parentKeepAlive) {
960   auto &liveOperations = contextRef->liveOperations;
961   assert(liveOperations.count(operation.ptr) == 0 &&
962          "cannot create detached operation that already exists");
963   (void)liveOperations;
964 
965   PyOperationRef created = createInstance(std::move(contextRef), operation,
966                                           std::move(parentKeepAlive));
967   created->attached = false;
968   return created;
969 }
970 
971 void PyOperation::checkValid() const {
972   if (!valid) {
973     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
974   }
975 }
976 
977 void PyOperationBase::print(py::object fileObject, bool binary,
978                             llvm::Optional<int64_t> largeElementsLimit,
979                             bool enableDebugInfo, bool prettyDebugInfo,
980                             bool printGenericOpForm, bool useLocalScope,
981                             bool assumeVerified) {
982   PyOperation &operation = getOperation();
983   operation.checkValid();
984   if (fileObject.is_none())
985     fileObject = py::module::import("sys").attr("stdout");
986 
987   if (!assumeVerified && !printGenericOpForm &&
988       !mlirOperationVerify(operation)) {
989     std::string message("// Verification failed, printing generic form\n");
990     if (binary) {
991       fileObject.attr("write")(py::bytes(message));
992     } else {
993       fileObject.attr("write")(py::str(message));
994     }
995     printGenericOpForm = true;
996   }
997 
998   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
999   if (largeElementsLimit)
1000     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1001   if (enableDebugInfo)
1002     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
1003   if (printGenericOpForm)
1004     mlirOpPrintingFlagsPrintGenericOpForm(flags);
1005 
1006   PyFileAccumulator accum(fileObject, binary);
1007   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1008                               accum.getUserData());
1009   mlirOpPrintingFlagsDestroy(flags);
1010 }
1011 
1012 py::object PyOperationBase::getAsm(bool binary,
1013                                    llvm::Optional<int64_t> largeElementsLimit,
1014                                    bool enableDebugInfo, bool prettyDebugInfo,
1015                                    bool printGenericOpForm, bool useLocalScope,
1016                                    bool assumeVerified) {
1017   py::object fileObject;
1018   if (binary) {
1019     fileObject = py::module::import("io").attr("BytesIO")();
1020   } else {
1021     fileObject = py::module::import("io").attr("StringIO")();
1022   }
1023   print(fileObject, /*binary=*/binary,
1024         /*largeElementsLimit=*/largeElementsLimit,
1025         /*enableDebugInfo=*/enableDebugInfo,
1026         /*prettyDebugInfo=*/prettyDebugInfo,
1027         /*printGenericOpForm=*/printGenericOpForm,
1028         /*useLocalScope=*/useLocalScope,
1029         /*assumeVerified=*/assumeVerified);
1030 
1031   return fileObject.attr("getvalue")();
1032 }
1033 
1034 void PyOperationBase::moveAfter(PyOperationBase &other) {
1035   PyOperation &operation = getOperation();
1036   PyOperation &otherOp = other.getOperation();
1037   operation.checkValid();
1038   otherOp.checkValid();
1039   mlirOperationMoveAfter(operation, otherOp);
1040   operation.parentKeepAlive = otherOp.parentKeepAlive;
1041 }
1042 
1043 void PyOperationBase::moveBefore(PyOperationBase &other) {
1044   PyOperation &operation = getOperation();
1045   PyOperation &otherOp = other.getOperation();
1046   operation.checkValid();
1047   otherOp.checkValid();
1048   mlirOperationMoveBefore(operation, otherOp);
1049   operation.parentKeepAlive = otherOp.parentKeepAlive;
1050 }
1051 
1052 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
1053   checkValid();
1054   if (!isAttached())
1055     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
1056   MlirOperation operation = mlirOperationGetParentOperation(get());
1057   if (mlirOperationIsNull(operation))
1058     return {};
1059   return PyOperation::forOperation(getContext(), operation);
1060 }
1061 
1062 PyBlock PyOperation::getBlock() {
1063   checkValid();
1064   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
1065   MlirBlock block = mlirOperationGetBlock(get());
1066   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1067   assert(parentOperation && "Operation has no parent");
1068   return PyBlock{std::move(*parentOperation), block};
1069 }
1070 
1071 py::object PyOperation::getCapsule() {
1072   checkValid();
1073   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1074 }
1075 
1076 py::object PyOperation::createFromCapsule(py::object capsule) {
1077   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1078   if (mlirOperationIsNull(rawOperation))
1079     throw py::error_already_set();
1080   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1081   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1082       .releaseObject();
1083 }
1084 
1085 static void maybeInsertOperation(PyOperationRef &op,
1086                                  const py::object &maybeIp) {
1087   // InsertPoint active?
1088   if (!maybeIp.is(py::cast(false))) {
1089     PyInsertionPoint *ip;
1090     if (maybeIp.is_none()) {
1091       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1092     } else {
1093       ip = py::cast<PyInsertionPoint *>(maybeIp);
1094     }
1095     if (ip)
1096       ip->insert(*op.get());
1097   }
1098 }
1099 
/// Creates a new operation named `name` and returns its OpView.
///
/// The operation is built from the optional `results` types, `operands`
/// values, `attributes` (a dict of str -> Attribute), `successors` blocks
/// and `regions` freshly created empty regions, at `location` (or the
/// ambient default location). It starts out detached and is then inserted
/// according to `maybeIp` (see maybeInsertOperation).
///
/// Raises ValueError for a negative region count or None
/// operands/results/successors, and throws py::cast_error for non-string
/// attribute keys or values that are not Attributes.
py::object PyOperation::create(
    const std::string &name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, const py::object &maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Attribute names are kept as owned strings here because the named
  // attributes built below reference their bytes.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  // (All validation that can raise has already happened above; the location
  // is resolved while building the state.)
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Each requested region starts out empty; ownership of the regions is
    // handed to the operation state.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  // It starts detached (owned by its wrapper) and is only inserted if
  // `maybeIp` requests it.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created->createOpView();
}
1221 
1222 py::object PyOperation::clone(const py::object &maybeIp) {
1223   MlirOperation clonedOperation = mlirOperationClone(operation);
1224   PyOperationRef cloned =
1225       PyOperation::createDetached(getContext(), clonedOperation);
1226   maybeInsertOperation(cloned, maybeIp);
1227 
1228   return cloned->createOpView();
1229 }
1230 
1231 py::object PyOperation::createOpView() {
1232   checkValid();
1233   MlirIdentifier ident = mlirOperationGetName(get());
1234   MlirStringRef identStr = mlirIdentifierStr(ident);
1235   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1236       StringRef(identStr.data, identStr.length));
1237   if (opViewClass)
1238     return (*opViewClass)(getRef().getObject());
1239   return py::cast(PyOpView(getRef().getObject()));
1240 }
1241 
1242 void PyOperation::erase() {
1243   checkValid();
1244   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1245   // Python reference to a child operation is live. All children should also
1246   // have their `valid` bit set to false.
1247   auto &liveOperations = getContext()->liveOperations;
1248   if (liveOperations.count(operation.ptr))
1249     liveOperations.erase(operation.ptr);
1250   mlirOperationDestroy(operation);
1251   valid = false;
1252 }
1253 
1254 //------------------------------------------------------------------------------
1255 // PyOpView
1256 //------------------------------------------------------------------------------
1257 
1258 py::object PyOpView::buildGeneric(
1259     const py::object &cls, py::list resultTypeList, py::list operandList,
1260     llvm::Optional<py::dict> attributes,
1261     llvm::Optional<std::vector<PyBlock *>> successors,
1262     llvm::Optional<int> regions, DefaultingPyLocation location,
1263     const py::object &maybeIp) {
1264   PyMlirContextRef context = location->getContext();
1265   // Class level operation construction metadata.
1266   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1267   // Operand and result segment specs are either none, which does no
1268   // variadic unpacking, or a list of ints with segment sizes, where each
1269   // element is either a positive number (typically 1 for a scalar) or -1 to
1270   // indicate that it is derived from the length of the same-indexed operand
1271   // or result (implying that it is a list at that position).
1272   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1273   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1274 
1275   std::vector<uint32_t> operandSegmentLengths;
1276   std::vector<uint32_t> resultSegmentLengths;
1277 
1278   // Validate/determine region count.
1279   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1280   int opMinRegionCount = std::get<0>(opRegionSpec);
1281   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1282   if (!regions) {
1283     regions = opMinRegionCount;
1284   }
1285   if (*regions < opMinRegionCount) {
1286     throw py::value_error(
1287         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1288          llvm::Twine(opMinRegionCount) +
1289          " regions but was built with regions=" + llvm::Twine(*regions))
1290             .str());
1291   }
1292   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1293     throw py::value_error(
1294         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1295          llvm::Twine(opMinRegionCount) +
1296          " regions but was built with regions=" + llvm::Twine(*regions))
1297             .str());
1298   }
1299 
1300   // Unpack results.
1301   std::vector<PyType *> resultTypes;
1302   resultTypes.reserve(resultTypeList.size());
1303   if (resultSegmentSpecObj.is_none()) {
1304     // Non-variadic result unpacking.
1305     for (const auto &it : llvm::enumerate(resultTypeList)) {
1306       try {
1307         resultTypes.push_back(py::cast<PyType *>(it.value()));
1308         if (!resultTypes.back())
1309           throw py::cast_error();
1310       } catch (py::cast_error &err) {
1311         throw py::value_error((llvm::Twine("Result ") +
1312                                llvm::Twine(it.index()) + " of operation \"" +
1313                                name + "\" must be a Type (" + err.what() + ")")
1314                                   .str());
1315       }
1316     }
1317   } else {
1318     // Sized result unpacking.
1319     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1320     if (resultSegmentSpec.size() != resultTypeList.size()) {
1321       throw py::value_error((llvm::Twine("Operation \"") + name +
1322                              "\" requires " +
1323                              llvm::Twine(resultSegmentSpec.size()) +
1324                              " result segments but was provided " +
1325                              llvm::Twine(resultTypeList.size()))
1326                                 .str());
1327     }
1328     resultSegmentLengths.reserve(resultTypeList.size());
1329     for (const auto &it :
1330          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1331       int segmentSpec = std::get<1>(it.value());
1332       if (segmentSpec == 1 || segmentSpec == 0) {
1333         // Unpack unary element.
1334         try {
1335           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1336           if (resultType) {
1337             resultTypes.push_back(resultType);
1338             resultSegmentLengths.push_back(1);
1339           } else if (segmentSpec == 0) {
1340             // Allowed to be optional.
1341             resultSegmentLengths.push_back(0);
1342           } else {
1343             throw py::cast_error("was None and result is not optional");
1344           }
1345         } catch (py::cast_error &err) {
1346           throw py::value_error((llvm::Twine("Result ") +
1347                                  llvm::Twine(it.index()) + " of operation \"" +
1348                                  name + "\" must be a Type (" + err.what() +
1349                                  ")")
1350                                     .str());
1351         }
1352       } else if (segmentSpec == -1) {
1353         // Unpack sequence by appending.
1354         try {
1355           if (std::get<0>(it.value()).is_none()) {
1356             // Treat it as an empty list.
1357             resultSegmentLengths.push_back(0);
1358           } else {
1359             // Unpack the list.
1360             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1361             for (py::object segmentItem : segment) {
1362               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1363               if (!resultTypes.back()) {
1364                 throw py::cast_error("contained a None item");
1365               }
1366             }
1367             resultSegmentLengths.push_back(segment.size());
1368           }
1369         } catch (std::exception &err) {
1370           // NOTE: Sloppy to be using a catch-all here, but there are at least
1371           // three different unrelated exceptions that can be thrown in the
1372           // above "casts". Just keep the scope above small and catch them all.
1373           throw py::value_error((llvm::Twine("Result ") +
1374                                  llvm::Twine(it.index()) + " of operation \"" +
1375                                  name + "\" must be a Sequence of Types (" +
1376                                  err.what() + ")")
1377                                     .str());
1378         }
1379       } else {
1380         throw py::value_error("Unexpected segment spec");
1381       }
1382     }
1383   }
1384 
1385   // Unpack operands.
1386   std::vector<PyValue *> operands;
1387   operands.reserve(operands.size());
1388   if (operandSegmentSpecObj.is_none()) {
1389     // Non-sized operand unpacking.
1390     for (const auto &it : llvm::enumerate(operandList)) {
1391       try {
1392         operands.push_back(py::cast<PyValue *>(it.value()));
1393         if (!operands.back())
1394           throw py::cast_error();
1395       } catch (py::cast_error &err) {
1396         throw py::value_error((llvm::Twine("Operand ") +
1397                                llvm::Twine(it.index()) + " of operation \"" +
1398                                name + "\" must be a Value (" + err.what() + ")")
1399                                   .str());
1400       }
1401     }
1402   } else {
1403     // Sized operand unpacking.
1404     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1405     if (operandSegmentSpec.size() != operandList.size()) {
1406       throw py::value_error((llvm::Twine("Operation \"") + name +
1407                              "\" requires " +
1408                              llvm::Twine(operandSegmentSpec.size()) +
1409                              "operand segments but was provided " +
1410                              llvm::Twine(operandList.size()))
1411                                 .str());
1412     }
1413     operandSegmentLengths.reserve(operandList.size());
1414     for (const auto &it :
1415          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1416       int segmentSpec = std::get<1>(it.value());
1417       if (segmentSpec == 1 || segmentSpec == 0) {
1418         // Unpack unary element.
1419         try {
1420           auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1421           if (operandValue) {
1422             operands.push_back(operandValue);
1423             operandSegmentLengths.push_back(1);
1424           } else if (segmentSpec == 0) {
1425             // Allowed to be optional.
1426             operandSegmentLengths.push_back(0);
1427           } else {
1428             throw py::cast_error("was None and operand is not optional");
1429           }
1430         } catch (py::cast_error &err) {
1431           throw py::value_error((llvm::Twine("Operand ") +
1432                                  llvm::Twine(it.index()) + " of operation \"" +
1433                                  name + "\" must be a Value (" + err.what() +
1434                                  ")")
1435                                     .str());
1436         }
1437       } else if (segmentSpec == -1) {
1438         // Unpack sequence by appending.
1439         try {
1440           if (std::get<0>(it.value()).is_none()) {
1441             // Treat it as an empty list.
1442             operandSegmentLengths.push_back(0);
1443           } else {
1444             // Unpack the list.
1445             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1446             for (py::object segmentItem : segment) {
1447               operands.push_back(py::cast<PyValue *>(segmentItem));
1448               if (!operands.back()) {
1449                 throw py::cast_error("contained a None item");
1450               }
1451             }
1452             operandSegmentLengths.push_back(segment.size());
1453           }
1454         } catch (std::exception &err) {
1455           // NOTE: Sloppy to be using a catch-all here, but there are at least
1456           // three different unrelated exceptions that can be thrown in the
1457           // above "casts". Just keep the scope above small and catch them all.
1458           throw py::value_error((llvm::Twine("Operand ") +
1459                                  llvm::Twine(it.index()) + " of operation \"" +
1460                                  name + "\" must be a Sequence of Values (" +
1461                                  err.what() + ")")
1462                                     .str());
1463         }
1464       } else {
1465         throw py::value_error("Unexpected segment spec");
1466       }
1467     }
1468   }
1469 
1470   // Merge operand/result segment lengths into attributes if needed.
1471   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1472     // Dup.
1473     if (attributes) {
1474       attributes = py::dict(*attributes);
1475     } else {
1476       attributes = py::dict();
1477     }
1478     if (attributes->contains("result_segment_sizes") ||
1479         attributes->contains("operand_segment_sizes")) {
1480       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1481                             "'operand_segment_sizes' attribute is unsupported. "
1482                             "Use Operation.create for such low-level access.");
1483     }
1484 
1485     // Add result_segment_sizes attribute.
1486     if (!resultSegmentLengths.empty()) {
1487       int64_t size = resultSegmentLengths.size();
1488       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1489           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1490           resultSegmentLengths.size(), resultSegmentLengths.data());
1491       (*attributes)["result_segment_sizes"] =
1492           PyAttribute(context, segmentLengthAttr);
1493     }
1494 
1495     // Add operand_segment_sizes attribute.
1496     if (!operandSegmentLengths.empty()) {
1497       int64_t size = operandSegmentLengths.size();
1498       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1499           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1500           operandSegmentLengths.size(), operandSegmentLengths.data());
1501       (*attributes)["operand_segment_sizes"] =
1502           PyAttribute(context, segmentLengthAttr);
1503     }
1504   }
1505 
1506   // Delegate to create.
1507   return PyOperation::create(name,
1508                              /*results=*/std::move(resultTypes),
1509                              /*operands=*/std::move(operands),
1510                              /*attributes=*/std::move(attributes),
1511                              /*successors=*/std::move(successors),
1512                              /*regions=*/*regions, location, maybeIp);
1513 }
1514 
/// Wraps an existing operation object (a PyOperation or any other
/// PyOperationBase subclass) as an OpView.
///
/// NOTE(review): in the first initializer `operationObject` names the
/// constructor parameter (it shadows the member of the same name). The second
/// initializer reads the already-initialized `operation` member, which relies
/// on `operation` being declared before `operationObject` in the class.
PyOpView::PyOpView(const py::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1520 
1521 py::object PyOpView::createRawSubclass(const py::object &userClass) {
1522   // This is... a little gross. The typical pattern is to have a pure python
1523   // class that extends OpView like:
1524   //   class AddFOp(_cext.ir.OpView):
1525   //     def __init__(self, loc, lhs, rhs):
1526   //       operation = loc.context.create_operation(
1527   //           "addf", lhs, rhs, results=[lhs.type])
1528   //       super().__init__(operation)
1529   //
1530   // I.e. The goal of the user facing type is to provide a nice constructor
1531   // that has complete freedom for the op under construction. This is at odds
1532   // with our other desire to sometimes create this object by just passing an
1533   // operation (to initialize the base class). We could do *arg and **kwargs
1534   // munging to try to make it work, but instead, we synthesize a new class
1535   // on the fly which extends this user class (AddFOp in this example) and
1536   // *give it* the base class's __init__ method, thus bypassing the
1537   // intermediate subclass's __init__ method entirely. While slightly,
1538   // underhanded, this is safe/legal because the type hierarchy has not changed
1539   // (we just added a new leaf) and we aren't mucking around with __new__.
1540   // Typically, this new class will be stored on the original as "_Raw" and will
1541   // be used for casts and other things that need a variant of the class that
1542   // is initialized purely from an operation.
1543   py::object parentMetaclass =
1544       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1545   py::dict attributes;
1546   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1547   // now.
1548   //   auto opViewType = py::type::of<PyOpView>();
1549   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1550   attributes["__init__"] = opViewType.attr("__init__");
1551   py::str origName = userClass.attr("__name__");
1552   py::str newName = py::str("_") + origName;
1553   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1554 }
1555 
1556 //------------------------------------------------------------------------------
1557 // PyInsertionPoint.
1558 //------------------------------------------------------------------------------
1559 
/// Creates an insertion point with no reference operation, i.e. one that
/// inserts at the end of `block`.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}

/// Creates an insertion point that inserts immediately before the operation
/// wrapped by `beforeOperationBase`, within that operation's block.
/// NOTE(review): `block` is initialized by dereferencing `refOperation`, so this
/// relies on `refOperation` being declared before `block` in the class
/// definition (not visible here) — confirm before reordering members.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1565 
1566 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1567   PyOperation &operation = operationBase.getOperation();
1568   if (operation.isAttached())
1569     throw SetPyError(PyExc_ValueError,
1570                      "Attempt to insert operation that is already attached");
1571   block.getParentOperation()->checkValid();
1572   MlirOperation beforeOp = {nullptr};
1573   if (refOperation) {
1574     // Insert before operation.
1575     (*refOperation)->checkValid();
1576     beforeOp = (*refOperation)->get();
1577   } else {
1578     // Insert at end (before null) is only valid if the block does not
1579     // already end in a known terminator (violating this will cause assertion
1580     // failures later).
1581     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1582       throw py::index_error("Cannot insert operation at the end of a block "
1583                             "that already has a terminator. Did you mean to "
1584                             "use 'InsertionPoint.at_block_terminator(block)' "
1585                             "versus 'InsertionPoint(block)'?");
1586     }
1587   }
1588   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1589   operation.setAttached();
1590 }
1591 
1592 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1593   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1594   if (mlirOperationIsNull(firstOp)) {
1595     // Just insert at end.
1596     return PyInsertionPoint(block);
1597   }
1598 
1599   // Insert before first op.
1600   PyOperationRef firstOpRef = PyOperation::forOperation(
1601       block.getParentOperation()->getContext(), firstOp);
1602   return PyInsertionPoint{block, std::move(firstOpRef)};
1603 }
1604 
1605 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1606   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1607   if (mlirOperationIsNull(terminator))
1608     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1609   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1610       block.getParentOperation()->getContext(), terminator);
1611   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1612 }
1613 
1614 py::object PyInsertionPoint::contextEnter() {
1615   return PyThreadContextEntry::pushInsertionPoint(*this);
1616 }
1617 
1618 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1619                                    const pybind11::object &excVal,
1620                                    const pybind11::object &excTb) {
1621   PyThreadContextEntry::popInsertionPoint(*this);
1622 }
1623 
1624 //------------------------------------------------------------------------------
1625 // PyAttribute.
1626 //------------------------------------------------------------------------------
1627 
1628 bool PyAttribute::operator==(const PyAttribute &other) {
1629   return mlirAttributeEqual(attr, other.attr);
1630 }
1631 
1632 py::object PyAttribute::getCapsule() {
1633   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1634 }
1635 
1636 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1637   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1638   if (mlirAttributeIsNull(rawAttr))
1639     throw py::error_already_set();
1640   return PyAttribute(
1641       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1642 }
1643 
1644 //------------------------------------------------------------------------------
1645 // PyNamedAttribute.
1646 //------------------------------------------------------------------------------
1647 
/// Builds a named attribute from `attr` and a name string. The name is copied
/// into heap storage owned by this object, so the MlirStringRef handed to
/// mlirIdentifierGet points at memory this object owns.
/// NOTE(review): `namedAttr` is assigned in the body rather than the
/// initializer list, presumably because it must be built after `ownedName`
/// has been initialized (member declaration order is not visible here).
/// Confirm before moving it into the initializer list.
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
    : ownedName(new std::string(std::move(ownedName))) {
  // Use the member (this->ownedName), not the moved-from parameter that
  // shares its name.
  namedAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(mlirAttributeGetContext(attr),
                        toMlirStringRef(*this->ownedName)),
      attr);
}
1655 
1656 //------------------------------------------------------------------------------
1657 // PyType.
1658 //------------------------------------------------------------------------------
1659 
1660 bool PyType::operator==(const PyType &other) {
1661   return mlirTypeEqual(type, other.type);
1662 }
1663 
1664 py::object PyType::getCapsule() {
1665   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1666 }
1667 
1668 PyType PyType::createFromCapsule(py::object capsule) {
1669   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1670   if (mlirTypeIsNull(rawType))
1671     throw py::error_already_set();
1672   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1673                 rawType);
1674 }
1675 
1676 //------------------------------------------------------------------------------
// PyValue and subclasses.
1678 //------------------------------------------------------------------------------
1679 
1680 pybind11::object PyValue::getCapsule() {
1681   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1682 }
1683 
1684 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1685   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1686   if (mlirValueIsNull(value))
1687     throw py::error_already_set();
1688   MlirOperation owner;
1689   if (mlirValueIsAOpResult(value))
1690     owner = mlirOpResultGetOwner(value);
1691   if (mlirValueIsABlockArgument(value))
1692     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1693   if (mlirOperationIsNull(owner))
1694     throw py::error_already_set();
1695   MlirContext ctx = mlirOperationGetContext(owner);
1696   PyOperationRef ownerRef =
1697       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1698   return PyValue(ownerRef, value);
1699 }
1700 
1701 //------------------------------------------------------------------------------
1702 // PySymbolTable.
1703 //------------------------------------------------------------------------------
1704 
1705 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1706     : operation(operation.getOperation().getRef()) {
1707   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1708   if (mlirSymbolTableIsNull(symbolTable)) {
1709     throw py::cast_error("Operation is not a Symbol Table.");
1710   }
1711 }
1712 
1713 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1714   operation->checkValid();
1715   MlirOperation symbol = mlirSymbolTableLookup(
1716       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1717   if (mlirOperationIsNull(symbol))
1718     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1719 
1720   return PyOperation::forOperation(operation->getContext(), symbol,
1721                                    operation.getObject())
1722       ->createOpView();
1723 }
1724 
1725 void PySymbolTable::erase(PyOperationBase &symbol) {
1726   operation->checkValid();
1727   symbol.getOperation().checkValid();
1728   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1729   // The operation is also erased, so we must invalidate it. There may be Python
1730   // references to this operation so we don't want to delete it from the list of
1731   // live operations here.
1732   symbol.getOperation().valid = false;
1733 }
1734 
1735 void PySymbolTable::dunderDel(const std::string &name) {
1736   py::object operation = dunderGetItem(name);
1737   erase(py::cast<PyOperationBase &>(operation));
1738 }
1739 
1740 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1741   operation->checkValid();
1742   symbol.getOperation().checkValid();
1743   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1744       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1745   if (mlirAttributeIsNull(symbolAttr))
1746     throw py::value_error("Expected operation to have a symbol name.");
1747   return PyAttribute(
1748       symbol.getOperation().getContext(),
1749       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1750 }
1751 
1752 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1753   // Op must already be a symbol.
1754   PyOperation &operation = symbol.getOperation();
1755   operation.checkValid();
1756   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1757   MlirAttribute existingNameAttr =
1758       mlirOperationGetAttributeByName(operation.get(), attrName);
1759   if (mlirAttributeIsNull(existingNameAttr))
1760     throw py::value_error("Expected operation to have a symbol name.");
1761   return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1762 }
1763 
1764 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1765                                   const std::string &name) {
1766   // Op must already be a symbol.
1767   PyOperation &operation = symbol.getOperation();
1768   operation.checkValid();
1769   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1770   MlirAttribute existingNameAttr =
1771       mlirOperationGetAttributeByName(operation.get(), attrName);
1772   if (mlirAttributeIsNull(existingNameAttr))
1773     throw py::value_error("Expected operation to have a symbol name.");
1774   MlirAttribute newNameAttr =
1775       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1776   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1777 }
1778 
1779 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1780   PyOperation &operation = symbol.getOperation();
1781   operation.checkValid();
1782   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1783   MlirAttribute existingVisAttr =
1784       mlirOperationGetAttributeByName(operation.get(), attrName);
1785   if (mlirAttributeIsNull(existingVisAttr))
1786     throw py::value_error("Expected operation to have a symbol visibility.");
1787   return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1788 }
1789 
1790 void PySymbolTable::setVisibility(PyOperationBase &symbol,
1791                                   const std::string &visibility) {
1792   if (visibility != "public" && visibility != "private" &&
1793       visibility != "nested")
1794     throw py::value_error(
1795         "Expected visibility to be 'public', 'private' or 'nested'");
1796   PyOperation &operation = symbol.getOperation();
1797   operation.checkValid();
1798   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1799   MlirAttribute existingVisAttr =
1800       mlirOperationGetAttributeByName(operation.get(), attrName);
1801   if (mlirAttributeIsNull(existingVisAttr))
1802     throw py::value_error("Expected operation to have a symbol visibility.");
1803   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1804                                                toMlirStringRef(visibility));
1805   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1806 }
1807 
1808 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1809                                          const std::string &newSymbol,
1810                                          PyOperationBase &from) {
1811   PyOperation &fromOperation = from.getOperation();
1812   fromOperation.checkValid();
1813   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1814           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1815           from.getOperation())))
1816 
1817     throw py::value_error("Symbol rename failed");
1818 }
1819 
1820 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1821                                      bool allSymUsesVisible,
1822                                      py::object callback) {
1823   PyOperation &fromOperation = from.getOperation();
1824   fromOperation.checkValid();
1825   struct UserData {
1826     PyMlirContextRef context;
1827     py::object callback;
1828     bool gotException;
1829     std::string exceptionWhat;
1830     py::object exceptionType;
1831   };
1832   UserData userData{
1833       fromOperation.getContext(), std::move(callback), false, {}, {}};
1834   mlirSymbolTableWalkSymbolTables(
1835       fromOperation.get(), allSymUsesVisible,
1836       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1837         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1838         auto pyFoundOp =
1839             PyOperation::forOperation(calleeUserData->context, foundOp);
1840         if (calleeUserData->gotException)
1841           return;
1842         try {
1843           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1844         } catch (py::error_already_set &e) {
1845           calleeUserData->gotException = true;
1846           calleeUserData->exceptionWhat = e.what();
1847           calleeUserData->exceptionType = e.type();
1848         }
1849       },
1850       static_cast<void *>(&userData));
1851   if (userData.gotException) {
1852     std::string message("Exception raised in callback: ");
1853     message.append(userData.exceptionWhat);
1854     throw std::runtime_error(message);
1855   }
1856 }
1857 
1858 namespace {
1859 /// CRTP base class for Python MLIR values that subclass Value and should be
1860 /// castable from it. The value hierarchy is one level deep and is not supposed
1861 /// to accommodate other levels unless core MLIR changes.
1862 template <typename DerivedTy>
1863 class PyConcreteValue : public PyValue {
1864 public:
1865   // Derived classes must define statics for:
1866   //   IsAFunctionTy isaFunction
1867   //   const char *pyClassName
1868   // and redefine bindDerived.
1869   using ClassTy = py::class_<DerivedTy, PyValue>;
1870   using IsAFunctionTy = bool (*)(MlirValue);
1871 
1872   PyConcreteValue() = default;
1873   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1874       : PyValue(operationRef, value) {}
1875   PyConcreteValue(PyValue &orig)
1876       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1877 
1878   /// Attempts to cast the original value to the derived type and throws on
1879   /// type mismatches.
1880   static MlirValue castFrom(PyValue &orig) {
1881     if (!DerivedTy::isaFunction(orig.get())) {
1882       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1883       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1884                                              DerivedTy::pyClassName +
1885                                              " (from " + origRepr + ")");
1886     }
1887     return orig.get();
1888   }
1889 
1890   /// Binds the Python module objects to functions of this class.
1891   static void bind(py::module &m) {
1892     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1893     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1894     cls.def_static(
1895         "isinstance",
1896         [](PyValue &otherValue) -> bool {
1897           return DerivedTy::isaFunction(otherValue);
1898         },
1899         py::arg("other_value"));
1900     DerivedTy::bindDerived(cls);
1901   }
1902 
1903   /// Implemented by derived classes to add methods to the Python subclass.
1904   static void bindDerived(ClassTy &m) {}
1905 };
1906 
1907 /// Python wrapper for MlirBlockArgument.
1908 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1909 public:
1910   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1911   static constexpr const char *pyClassName = "BlockArgument";
1912   using PyConcreteValue::PyConcreteValue;
1913 
1914   static void bindDerived(ClassTy &c) {
1915     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1916       return PyBlock(self.getParentOperation(),
1917                      mlirBlockArgumentGetOwner(self.get()));
1918     });
1919     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1920       return mlirBlockArgumentGetArgNumber(self.get());
1921     });
1922     c.def(
1923         "set_type",
1924         [](PyBlockArgument &self, PyType type) {
1925           return mlirBlockArgumentSetType(self.get(), type);
1926         },
1927         py::arg("type"));
1928   }
1929 };
1930 
1931 /// Python wrapper for MlirOpResult.
1932 class PyOpResult : public PyConcreteValue<PyOpResult> {
1933 public:
1934   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1935   static constexpr const char *pyClassName = "OpResult";
1936   using PyConcreteValue::PyConcreteValue;
1937 
1938   static void bindDerived(ClassTy &c) {
1939     c.def_property_readonly("owner", [](PyOpResult &self) {
1940       assert(
1941           mlirOperationEqual(self.getParentOperation()->get(),
1942                              mlirOpResultGetOwner(self.get())) &&
1943           "expected the owner of the value in Python to match that in the IR");
1944       return self.getParentOperation().getObject();
1945     });
1946     c.def_property_readonly("result_number", [](PyOpResult &self) {
1947       return mlirOpResultGetResultNumber(self.get());
1948     });
1949   }
1950 };
1951 
1952 /// Returns the list of types of the values held by container.
1953 template <typename Container>
1954 static std::vector<PyType> getValueTypes(Container &container,
1955                                          PyMlirContextRef &context) {
1956   std::vector<PyType> result;
1957   result.reserve(container.getNumElements());
1958   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1959     result.push_back(
1960         PyType(context, mlirValueGetType(container.getElement(i).get())));
1961   }
1962   return result;
1963 }
1964 
1965 /// A list of block arguments. Internally, these are stored as consecutive
1966 /// elements, random access is cheap. The argument list is associated with the
1967 /// operation that contains the block (detached blocks are not allowed in
1968 /// Python bindings) and extends its lifetime.
1969 class PyBlockArgumentList
1970     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1971 public:
1972   static constexpr const char *pyClassName = "BlockArgumentList";
1973 
1974   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1975                       intptr_t startIndex = 0, intptr_t length = -1,
1976                       intptr_t step = 1)
1977       : Sliceable(startIndex,
1978                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1979                   step),
1980         operation(std::move(operation)), block(block) {}
1981 
1982   /// Returns the number of arguments in the list.
1983   intptr_t getNumElements() {
1984     operation->checkValid();
1985     return mlirBlockGetNumArguments(block);
1986   }
1987 
1988   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1989   PyBlockArgument getElement(intptr_t pos) {
1990     MlirValue argument = mlirBlockGetArgument(block, pos);
1991     return PyBlockArgument(operation, argument);
1992   }
1993 
1994   /// Returns a sublist of this list.
1995   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1996                             intptr_t step) {
1997     return PyBlockArgumentList(operation, block, startIndex, length, step);
1998   }
1999 
2000   static void bindDerived(ClassTy &c) {
2001     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
2002       return getValueTypes(self, self.operation->getContext());
2003     });
2004   }
2005 
2006 private:
2007   PyOperationRef operation;
2008   MlirBlock block;
2009 };
2010 
2011 /// A list of operation operands. Internally, these are stored as consecutive
2012 /// elements, random access is cheap. The result list is associated with the
2013 /// operation whose results these are, and extends the lifetime of this
2014 /// operation.
2015 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2016 public:
2017   static constexpr const char *pyClassName = "OpOperandList";
2018 
2019   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2020                   intptr_t length = -1, intptr_t step = 1)
2021       : Sliceable(startIndex,
2022                   length == -1 ? mlirOperationGetNumOperands(operation->get())
2023                                : length,
2024                   step),
2025         operation(operation) {}
2026 
2027   intptr_t getNumElements() {
2028     operation->checkValid();
2029     return mlirOperationGetNumOperands(operation->get());
2030   }
2031 
2032   PyValue getElement(intptr_t pos) {
2033     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2034     MlirOperation owner;
2035     if (mlirValueIsAOpResult(operand))
2036       owner = mlirOpResultGetOwner(operand);
2037     else if (mlirValueIsABlockArgument(operand))
2038       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2039     else
2040       assert(false && "Value must be an block arg or op result.");
2041     PyOperationRef pyOwner =
2042         PyOperation::forOperation(operation->getContext(), owner);
2043     return PyValue(pyOwner, operand);
2044   }
2045 
2046   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2047     return PyOpOperandList(operation, startIndex, length, step);
2048   }
2049 
2050   void dunderSetItem(intptr_t index, PyValue value) {
2051     index = wrapIndex(index);
2052     mlirOperationSetOperand(operation->get(), index, value.get());
2053   }
2054 
2055   static void bindDerived(ClassTy &c) {
2056     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2057   }
2058 
2059 private:
2060   PyOperationRef operation;
2061 };
2062 
2063 /// A list of operation results. Internally, these are stored as consecutive
2064 /// elements, random access is cheap. The result list is associated with the
2065 /// operation whose results these are, and extends the lifetime of this
2066 /// operation.
2067 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2068 public:
2069   static constexpr const char *pyClassName = "OpResultList";
2070 
2071   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2072                  intptr_t length = -1, intptr_t step = 1)
2073       : Sliceable(startIndex,
2074                   length == -1 ? mlirOperationGetNumResults(operation->get())
2075                                : length,
2076                   step),
2077         operation(operation) {}
2078 
2079   intptr_t getNumElements() {
2080     operation->checkValid();
2081     return mlirOperationGetNumResults(operation->get());
2082   }
2083 
2084   PyOpResult getElement(intptr_t index) {
2085     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2086     return PyOpResult(value);
2087   }
2088 
2089   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2090     return PyOpResultList(operation, startIndex, length, step);
2091   }
2092 
2093   static void bindDerived(ClassTy &c) {
2094     c.def_property_readonly("types", [](PyOpResultList &self) {
2095       return getValueTypes(self, self.operation->getContext());
2096     });
2097   }
2098 
2099 private:
2100   PyOperationRef operation;
2101 };
2102 
2103 /// A list of operation attributes. Can be indexed by name, producing
2104 /// attributes, or by index, producing named attributes.
2105 class PyOpAttributeMap {
2106 public:
2107   PyOpAttributeMap(PyOperationRef operation)
2108       : operation(std::move(operation)) {}
2109 
2110   PyAttribute dunderGetItemNamed(const std::string &name) {
2111     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2112                                                          toMlirStringRef(name));
2113     if (mlirAttributeIsNull(attr)) {
2114       throw SetPyError(PyExc_KeyError,
2115                        "attempt to access a non-existent attribute");
2116     }
2117     return PyAttribute(operation->getContext(), attr);
2118   }
2119 
2120   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2121     if (index < 0 || index >= dunderLen()) {
2122       throw SetPyError(PyExc_IndexError,
2123                        "attempt to access out of bounds attribute");
2124     }
2125     MlirNamedAttribute namedAttr =
2126         mlirOperationGetAttribute(operation->get(), index);
2127     return PyNamedAttribute(
2128         namedAttr.attribute,
2129         std::string(mlirIdentifierStr(namedAttr.name).data,
2130                     mlirIdentifierStr(namedAttr.name).length));
2131   }
2132 
2133   void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2134     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2135                                     attr);
2136   }
2137 
2138   void dunderDelItem(const std::string &name) {
2139     int removed = mlirOperationRemoveAttributeByName(operation->get(),
2140                                                      toMlirStringRef(name));
2141     if (!removed)
2142       throw SetPyError(PyExc_KeyError,
2143                        "attempt to delete a non-existent attribute");
2144   }
2145 
2146   intptr_t dunderLen() {
2147     return mlirOperationGetNumAttributes(operation->get());
2148   }
2149 
2150   bool dunderContains(const std::string &name) {
2151     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2152         operation->get(), toMlirStringRef(name)));
2153   }
2154 
2155   static void bind(py::module &m) {
2156     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2157         .def("__contains__", &PyOpAttributeMap::dunderContains)
2158         .def("__len__", &PyOpAttributeMap::dunderLen)
2159         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2160         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2161         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2162         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2163   }
2164 
2165 private:
2166   PyOperationRef operation;
2167 };
2168 
2169 } // namespace
2170 
2171 //------------------------------------------------------------------------------
2172 // Populates the core exports of the 'ir' submodule.
2173 //------------------------------------------------------------------------------
2174 
2175 void mlir::python::populateIRCore(py::module &m) {
2176   //----------------------------------------------------------------------------
2177   // Enums.
2178   //----------------------------------------------------------------------------
2179   py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2180       .value("ERROR", MlirDiagnosticError)
2181       .value("WARNING", MlirDiagnosticWarning)
2182       .value("NOTE", MlirDiagnosticNote)
2183       .value("REMARK", MlirDiagnosticRemark);
2184 
2185   //----------------------------------------------------------------------------
2186   // Mapping of Diagnostics.
2187   //----------------------------------------------------------------------------
2188   py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2189       .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2190       .def_property_readonly("location", &PyDiagnostic::getLocation)
2191       .def_property_readonly("message", &PyDiagnostic::getMessage)
2192       .def_property_readonly("notes", &PyDiagnostic::getNotes)
2193       .def("__str__", [](PyDiagnostic &self) -> py::str {
2194         if (!self.isValid())
2195           return "<Invalid Diagnostic>";
2196         return self.getMessage();
2197       });
2198 
2199   py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2200       .def("detach", &PyDiagnosticHandler::detach)
2201       .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2202       .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2203       .def("__enter__", &PyDiagnosticHandler::contextEnter)
2204       .def("__exit__", &PyDiagnosticHandler::contextExit);
2205 
2206   //----------------------------------------------------------------------------
2207   // Mapping of MlirContext.
2208   //----------------------------------------------------------------------------
2209   py::class_<PyMlirContext>(m, "Context", py::module_local())
2210       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2211       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2212       .def("_get_context_again",
2213            [](PyMlirContext &self) {
2214              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2215              return ref.releaseObject();
2216            })
2217       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2218       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2219       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2220       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2221                              &PyMlirContext::getCapsule)
2222       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2223       .def("__enter__", &PyMlirContext::contextEnter)
2224       .def("__exit__", &PyMlirContext::contextExit)
2225       .def_property_readonly_static(
2226           "current",
2227           [](py::object & /*class*/) {
2228             auto *context = PyThreadContextEntry::getDefaultContext();
2229             if (!context)
2230               throw SetPyError(PyExc_ValueError, "No current Context");
2231             return context;
2232           },
2233           "Gets the Context bound to the current thread or raises ValueError")
2234       .def_property_readonly(
2235           "dialects",
2236           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2237           "Gets a container for accessing dialects by name")
2238       .def_property_readonly(
2239           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2240           "Alias for 'dialect'")
2241       .def(
2242           "get_dialect_descriptor",
2243           [=](PyMlirContext &self, std::string &name) {
2244             MlirDialect dialect = mlirContextGetOrLoadDialect(
2245                 self.get(), {name.data(), name.size()});
2246             if (mlirDialectIsNull(dialect)) {
2247               throw SetPyError(PyExc_ValueError,
2248                                Twine("Dialect '") + name + "' not found");
2249             }
2250             return PyDialectDescriptor(self.getRef(), dialect);
2251           },
2252           py::arg("dialect_name"),
2253           "Gets or loads a dialect by name, returning its descriptor object")
2254       .def_property(
2255           "allow_unregistered_dialects",
2256           [](PyMlirContext &self) -> bool {
2257             return mlirContextGetAllowUnregisteredDialects(self.get());
2258           },
2259           [](PyMlirContext &self, bool value) {
2260             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2261           })
2262       .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2263            py::arg("callback"),
2264            "Attaches a diagnostic handler that will receive callbacks")
2265       .def(
2266           "enable_multithreading",
2267           [](PyMlirContext &self, bool enable) {
2268             mlirContextEnableMultithreading(self.get(), enable);
2269           },
2270           py::arg("enable"))
2271       .def(
2272           "is_registered_operation",
2273           [](PyMlirContext &self, std::string &name) {
2274             return mlirContextIsRegisteredOperation(
2275                 self.get(), MlirStringRef{name.data(), name.size()});
2276           },
2277           py::arg("operation_name"));
2278 
2279   //----------------------------------------------------------------------------
2280   // Mapping of PyDialectDescriptor
2281   //----------------------------------------------------------------------------
2282   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2283       .def_property_readonly("namespace",
2284                              [](PyDialectDescriptor &self) {
2285                                MlirStringRef ns =
2286                                    mlirDialectGetNamespace(self.get());
2287                                return py::str(ns.data, ns.length);
2288                              })
2289       .def("__repr__", [](PyDialectDescriptor &self) {
2290         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2291         std::string repr("<DialectDescriptor ");
2292         repr.append(ns.data, ns.length);
2293         repr.append(">");
2294         return repr;
2295       });
2296 
2297   //----------------------------------------------------------------------------
2298   // Mapping of PyDialects
2299   //----------------------------------------------------------------------------
2300   py::class_<PyDialects>(m, "Dialects", py::module_local())
2301       .def("__getitem__",
2302            [=](PyDialects &self, std::string keyName) {
2303              MlirDialect dialect =
2304                  self.getDialectForKey(keyName, /*attrError=*/false);
2305              py::object descriptor =
2306                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2307              return createCustomDialectWrapper(keyName, std::move(descriptor));
2308            })
2309       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2310         MlirDialect dialect =
2311             self.getDialectForKey(attrName, /*attrError=*/true);
2312         py::object descriptor =
2313             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2314         return createCustomDialectWrapper(attrName, std::move(descriptor));
2315       });
2316 
2317   //----------------------------------------------------------------------------
2318   // Mapping of PyDialect
2319   //----------------------------------------------------------------------------
2320   py::class_<PyDialect>(m, "Dialect", py::module_local())
2321       .def(py::init<py::object>(), py::arg("descriptor"))
2322       .def_property_readonly(
2323           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2324       .def("__repr__", [](py::object self) {
2325         auto clazz = self.attr("__class__");
2326         return py::str("<Dialect ") +
2327                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2328                clazz.attr("__module__") + py::str(".") +
2329                clazz.attr("__name__") + py::str(")>");
2330       });
2331 
2332   //----------------------------------------------------------------------------
2333   // Mapping of Location
2334   //----------------------------------------------------------------------------
2335   py::class_<PyLocation>(m, "Location", py::module_local())
2336       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2337       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2338       .def("__enter__", &PyLocation::contextEnter)
2339       .def("__exit__", &PyLocation::contextExit)
2340       .def("__eq__",
2341            [](PyLocation &self, PyLocation &other) -> bool {
2342              return mlirLocationEqual(self, other);
2343            })
2344       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2345       .def_property_readonly_static(
2346           "current",
2347           [](py::object & /*class*/) {
2348             auto *loc = PyThreadContextEntry::getDefaultLocation();
2349             if (!loc)
2350               throw SetPyError(PyExc_ValueError, "No current Location");
2351             return loc;
2352           },
2353           "Gets the Location bound to the current thread or raises ValueError")
2354       .def_static(
2355           "unknown",
2356           [](DefaultingPyMlirContext context) {
2357             return PyLocation(context->getRef(),
2358                               mlirLocationUnknownGet(context->get()));
2359           },
2360           py::arg("context") = py::none(),
2361           "Gets a Location representing an unknown location")
2362       .def_static(
2363           "callsite",
2364           [](PyLocation callee, const std::vector<PyLocation> &frames,
2365              DefaultingPyMlirContext context) {
2366             if (frames.empty())
2367               throw py::value_error("No caller frames provided");
2368             MlirLocation caller = frames.back().get();
2369             for (const PyLocation &frame :
2370                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2371               caller = mlirLocationCallSiteGet(frame.get(), caller);
2372             return PyLocation(context->getRef(),
2373                               mlirLocationCallSiteGet(callee.get(), caller));
2374           },
2375           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2376           kContextGetCallSiteLocationDocstring)
2377       .def_static(
2378           "file",
2379           [](std::string filename, int line, int col,
2380              DefaultingPyMlirContext context) {
2381             return PyLocation(
2382                 context->getRef(),
2383                 mlirLocationFileLineColGet(
2384                     context->get(), toMlirStringRef(filename), line, col));
2385           },
2386           py::arg("filename"), py::arg("line"), py::arg("col"),
2387           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2388       .def_static(
2389           "fused",
2390           [](const std::vector<PyLocation> &pyLocations,
2391              llvm::Optional<PyAttribute> metadata,
2392              DefaultingPyMlirContext context) {
2393             llvm::SmallVector<MlirLocation, 4> locations;
2394             locations.reserve(pyLocations.size());
2395             for (auto &pyLocation : pyLocations)
2396               locations.push_back(pyLocation.get());
2397             MlirLocation location = mlirLocationFusedGet(
2398                 context->get(), locations.size(), locations.data(),
2399                 metadata ? metadata->get() : MlirAttribute{0});
2400             return PyLocation(context->getRef(), location);
2401           },
2402           py::arg("locations"), py::arg("metadata") = py::none(),
2403           py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2404       .def_static(
2405           "name",
2406           [](std::string name, llvm::Optional<PyLocation> childLoc,
2407              DefaultingPyMlirContext context) {
2408             return PyLocation(
2409                 context->getRef(),
2410                 mlirLocationNameGet(
2411                     context->get(), toMlirStringRef(name),
2412                     childLoc ? childLoc->get()
2413                              : mlirLocationUnknownGet(context->get())));
2414           },
2415           py::arg("name"), py::arg("childLoc") = py::none(),
2416           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2417       .def_property_readonly(
2418           "context",
2419           [](PyLocation &self) { return self.getContext().getObject(); },
2420           "Context that owns the Location")
2421       .def(
2422           "emit_error",
2423           [](PyLocation &self, std::string message) {
2424             mlirEmitError(self, message.c_str());
2425           },
2426           py::arg("message"), "Emits an error at this location")
2427       .def("__repr__", [](PyLocation &self) {
2428         PyPrintAccumulator printAccum;
2429         mlirLocationPrint(self, printAccum.getCallback(),
2430                           printAccum.getUserData());
2431         return printAccum.join();
2432       });
2433 
2434   //----------------------------------------------------------------------------
2435   // Mapping of Module
2436   //----------------------------------------------------------------------------
2437   py::class_<PyModule>(m, "Module", py::module_local())
2438       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2439       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2440       .def_static(
2441           "parse",
2442           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2443             MlirModule module = mlirModuleCreateParse(
2444                 context->get(), toMlirStringRef(moduleAsm));
2445             // TODO: Rework error reporting once diagnostic engine is exposed
2446             // in C API.
2447             if (mlirModuleIsNull(module)) {
2448               throw SetPyError(
2449                   PyExc_ValueError,
2450                   "Unable to parse module assembly (see diagnostics)");
2451             }
2452             return PyModule::forModule(module).releaseObject();
2453           },
2454           py::arg("asm"), py::arg("context") = py::none(),
2455           kModuleParseDocstring)
2456       .def_static(
2457           "create",
2458           [](DefaultingPyLocation loc) {
2459             MlirModule module = mlirModuleCreateEmpty(loc);
2460             return PyModule::forModule(module).releaseObject();
2461           },
2462           py::arg("loc") = py::none(), "Creates an empty module")
2463       .def_property_readonly(
2464           "context",
2465           [](PyModule &self) { return self.getContext().getObject(); },
2466           "Context that created the Module")
2467       .def_property_readonly(
2468           "operation",
2469           [](PyModule &self) {
2470             return PyOperation::forOperation(self.getContext(),
2471                                              mlirModuleGetOperation(self.get()),
2472                                              self.getRef().releaseObject())
2473                 .releaseObject();
2474           },
2475           "Accesses the module as an operation")
2476       .def_property_readonly(
2477           "body",
2478           [](PyModule &self) {
2479             PyOperationRef moduleOp = PyOperation::forOperation(
2480                 self.getContext(), mlirModuleGetOperation(self.get()),
2481                 self.getRef().releaseObject());
2482             PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2483             return returnBlock;
2484           },
2485           "Return the block for this module")
2486       .def(
2487           "dump",
2488           [](PyModule &self) {
2489             mlirOperationDump(mlirModuleGetOperation(self.get()));
2490           },
2491           kDumpDocstring)
2492       .def(
2493           "__str__",
2494           [](py::object self) {
2495             // Defer to the operation's __str__.
2496             return self.attr("operation").attr("__str__")();
2497           },
2498           kOperationStrDunderDocstring);
2499 
2500   //----------------------------------------------------------------------------
2501   // Mapping of Operation.
2502   //----------------------------------------------------------------------------
  // Bindings shared by Operation and OpView. Every accessor goes through
  // getOperation(), so the same Python API is available on both.
  py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                             [](PyOperationBase &self) {
                               return self.getOperation().getCapsule();
                             })
      // Identity comparison: equal when both sides resolve to the same
      // PyOperation object. pybind11 tries overloads in registration order, so
      // the typed overload must stay ahead of the py::object fallback, which
      // returns False for any other operand type.
      .def("__eq__",
           [](PyOperationBase &self, PyOperationBase &other) {
             return &self.getOperation() == &other.getOperation();
           })
      .def("__eq__",
           [](PyOperationBase &self, py::object other) { return false; })
      // Hashes the same PyOperation address that __eq__ compares, keeping
      // hashing consistent with equality.
      .def("__hash__",
           [](PyOperationBase &self) {
             return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
           })
      .def_property_readonly("attributes",
                             [](PyOperationBase &self) {
                               return PyOpAttributeMap(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("operands",
                             [](PyOperationBase &self) {
                               return PyOpOperandList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("regions",
                             [](PyOperationBase &self) {
                               return PyRegionList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly(
          "results",
          [](PyOperationBase &self) {
            return PyOpResultList(self.getOperation().getRef());
          },
          "Returns the list of Operation results.")
      // Raises ValueError (naming the op and its result count) unless the
      // operation has exactly one result.
      .def_property_readonly(
          "result",
          [](PyOperationBase &self) {
            auto &operation = self.getOperation();
            auto numResults = mlirOperationGetNumResults(operation);
            if (numResults != 1) {
              auto name = mlirIdentifierStr(mlirOperationGetName(operation));
              throw SetPyError(
                  PyExc_ValueError,
                  Twine("Cannot call .result on operation ") +
                      StringRef(name.data, name.length) + " which has " +
                      Twine(numResults) +
                      " results (it is only valid for operations with a "
                      "single result)");
            }
            return PyOpResult(operation.getRef(),
                              mlirOperationGetResult(operation, 0));
          },
          "Shortcut to get an op result if it has only one (throws an error "
          "otherwise).")
      .def_property_readonly(
          "location",
          [](PyOperationBase &self) {
            PyOperation &operation = self.getOperation();
            return PyLocation(operation.getContext(),
                              mlirOperationGetLocation(operation.get()));
          },
          "Returns the source location the operation was defined or derived "
          "from.")
      // str(op): getAsm with every printing flag off and no elements limit.
      .def(
          "__str__",
          [](PyOperationBase &self) {
            return self.getAsm(/*binary=*/false,
                               /*largeElementsLimit=*/llvm::None,
                               /*enableDebugInfo=*/false,
                               /*prettyDebugInfo=*/false,
                               /*printGenericOpForm=*/false,
                               /*useLocalScope=*/false,
                               /*assumeVerified=*/false);
          },
          "Returns the assembly form of the operation.")
      // The py::arg lists below bind positionally to the C++ parameters of
      // PyOperationBase::print / getAsm (declared elsewhere); keep their order
      // and count in sync with those signatures.
      .def("print", &PyOperationBase::print,
           // Careful: Lots of arguments must match up with print method.
           py::arg("file") = py::none(), py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false,
           py::arg("assume_verified") = false, kOperationPrintDocstring)
      .def("get_asm", &PyOperationBase::getAsm,
           // Careful: Lots of arguments must match up with get_asm method.
           py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false,
           py::arg("assume_verified") = false, kOperationGetAsmDocstring)
      .def(
          "verify",
          [](PyOperationBase &self) {
            return mlirOperationVerify(self.getOperation());
          },
          "Verify the operation and return true if it passes, false if it "
          "fails.")
      .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
           "Puts self immediately after the other operation in its parent "
           "block.")
      .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
           "Puts self immediately before the other operation in its parent "
           "block.")
      // Validates the operation, raises ValueError if it is not attached,
      // otherwise detaches it and returns its OpView (createOpView).
      .def(
          "detach_from_parent",
          [](PyOperationBase &self) {
            PyOperation &operation = self.getOperation();
            operation.checkValid();
            if (!operation.isAttached())
              throw py::value_error("Detached operation has no parent.");

            operation.detachFromParent();
            return operation.createOpView();
          },
          "Detaches the operation from its parent block.");
2623 
2624   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2625       .def_static("create", &PyOperation::create, py::arg("name"),
2626                   py::arg("results") = py::none(),
2627                   py::arg("operands") = py::none(),
2628                   py::arg("attributes") = py::none(),
2629                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2630                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2631                   kOperationCreateDocstring)
2632       .def_property_readonly("parent",
2633                              [](PyOperation &self) -> py::object {
2634                                auto parent = self.getParentOperation();
2635                                if (parent)
2636                                  return parent->getObject();
2637                                return py::none();
2638                              })
2639       .def("erase", &PyOperation::erase)
2640       .def("clone", &PyOperation::clone, py::arg("ip") = py::none())
2641       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2642                              &PyOperation::getCapsule)
2643       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2644       .def_property_readonly("name",
2645                              [](PyOperation &self) {
2646                                self.checkValid();
2647                                MlirOperation operation = self.get();
2648                                MlirStringRef name = mlirIdentifierStr(
2649                                    mlirOperationGetName(operation));
2650                                return py::str(name.data, name.length);
2651                              })
2652       .def_property_readonly(
2653           "context",
2654           [](PyOperation &self) {
2655             self.checkValid();
2656             return self.getContext().getObject();
2657           },
2658           "Context that owns the Operation")
2659       .def_property_readonly("opview", &PyOperation::createOpView);
2660 
2661   auto opViewClass =
2662       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2663           .def(py::init<py::object>(), py::arg("operation"))
2664           .def_property_readonly("operation", &PyOpView::getOperationObject)
2665           .def_property_readonly(
2666               "context",
2667               [](PyOpView &self) {
2668                 return self.getOperation().getContext().getObject();
2669               },
2670               "Context that owns the Operation")
2671           .def("__str__", [](PyOpView &self) {
2672             return py::str(self.getOperationObject());
2673           });
2674   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2675   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2676   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2677   opViewClass.attr("build_generic") = classmethod(
2678       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2679       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2680       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2681       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2682       "Builds a specific, generated OpView based on class level attributes.");
2683 
2684   //----------------------------------------------------------------------------
2685   // Mapping of PyRegion.
2686   //----------------------------------------------------------------------------
2687   py::class_<PyRegion>(m, "Region", py::module_local())
2688       .def_property_readonly(
2689           "blocks",
2690           [](PyRegion &self) {
2691             return PyBlockList(self.getParentOperation(), self.get());
2692           },
2693           "Returns a forward-optimized sequence of blocks.")
2694       .def_property_readonly(
2695           "owner",
2696           [](PyRegion &self) {
2697             return self.getParentOperation()->createOpView();
2698           },
2699           "Returns the operation owning this region.")
2700       .def(
2701           "__iter__",
2702           [](PyRegion &self) {
2703             self.checkValid();
2704             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2705             return PyBlockIterator(self.getParentOperation(), firstBlock);
2706           },
2707           "Iterates over blocks in the region.")
2708       .def("__eq__",
2709            [](PyRegion &self, PyRegion &other) {
2710              return self.get().ptr == other.get().ptr;
2711            })
2712       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2713 
2714   //----------------------------------------------------------------------------
2715   // Mapping of PyBlock.
2716   //----------------------------------------------------------------------------
2717   py::class_<PyBlock>(m, "Block", py::module_local())
2718       .def_property_readonly(
2719           "owner",
2720           [](PyBlock &self) {
2721             return self.getParentOperation()->createOpView();
2722           },
2723           "Returns the owning operation of this block.")
2724       .def_property_readonly(
2725           "region",
2726           [](PyBlock &self) {
2727             MlirRegion region = mlirBlockGetParentRegion(self.get());
2728             return PyRegion(self.getParentOperation(), region);
2729           },
2730           "Returns the owning region of this block.")
2731       .def_property_readonly(
2732           "arguments",
2733           [](PyBlock &self) {
2734             return PyBlockArgumentList(self.getParentOperation(), self.get());
2735           },
2736           "Returns a list of block arguments.")
2737       .def_property_readonly(
2738           "operations",
2739           [](PyBlock &self) {
2740             return PyOperationList(self.getParentOperation(), self.get());
2741           },
2742           "Returns a forward-optimized sequence of operations.")
2743       .def_static(
2744           "create_at_start",
2745           [](PyRegion &parent, py::list pyArgTypes) {
2746             parent.checkValid();
2747             llvm::SmallVector<MlirType, 4> argTypes;
2748             llvm::SmallVector<MlirLocation, 4> argLocs;
2749             argTypes.reserve(pyArgTypes.size());
2750             argLocs.reserve(pyArgTypes.size());
2751             for (auto &pyArg : pyArgTypes) {
2752               argTypes.push_back(pyArg.cast<PyType &>());
2753               // TODO: Pass in a proper location here.
2754               argLocs.push_back(
2755                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2756             }
2757 
2758             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2759                                               argLocs.data());
2760             mlirRegionInsertOwnedBlock(parent, 0, block);
2761             return PyBlock(parent.getParentOperation(), block);
2762           },
2763           py::arg("parent"), py::arg("arg_types") = py::list(),
2764           "Creates and returns a new Block at the beginning of the given "
2765           "region (with given argument types).")
2766       .def(
2767           "append_to",
2768           [](PyBlock &self, PyRegion &region) {
2769             MlirBlock b = self.get();
2770             if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
2771               mlirBlockDetach(b);
2772             mlirRegionAppendOwnedBlock(region.get(), b);
2773           },
2774           "Append this block to a region, transferring ownership if necessary")
2775       .def(
2776           "create_before",
2777           [](PyBlock &self, py::args pyArgTypes) {
2778             self.checkValid();
2779             llvm::SmallVector<MlirType, 4> argTypes;
2780             llvm::SmallVector<MlirLocation, 4> argLocs;
2781             argTypes.reserve(pyArgTypes.size());
2782             argLocs.reserve(pyArgTypes.size());
2783             for (auto &pyArg : pyArgTypes) {
2784               argTypes.push_back(pyArg.cast<PyType &>());
2785               // TODO: Pass in a proper location here.
2786               argLocs.push_back(
2787                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2788             }
2789 
2790             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2791                                               argLocs.data());
2792             MlirRegion region = mlirBlockGetParentRegion(self.get());
2793             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2794             return PyBlock(self.getParentOperation(), block);
2795           },
2796           "Creates and returns a new Block before this block "
2797           "(with given argument types).")
2798       .def(
2799           "create_after",
2800           [](PyBlock &self, py::args pyArgTypes) {
2801             self.checkValid();
2802             llvm::SmallVector<MlirType, 4> argTypes;
2803             llvm::SmallVector<MlirLocation, 4> argLocs;
2804             argTypes.reserve(pyArgTypes.size());
2805             argLocs.reserve(pyArgTypes.size());
2806             for (auto &pyArg : pyArgTypes) {
2807               argTypes.push_back(pyArg.cast<PyType &>());
2808 
2809               // TODO: Pass in a proper location here.
2810               argLocs.push_back(
2811                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2812             }
2813             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2814                                               argLocs.data());
2815             MlirRegion region = mlirBlockGetParentRegion(self.get());
2816             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2817             return PyBlock(self.getParentOperation(), block);
2818           },
2819           "Creates and returns a new Block after this block "
2820           "(with given argument types).")
2821       .def(
2822           "__iter__",
2823           [](PyBlock &self) {
2824             self.checkValid();
2825             MlirOperation firstOperation =
2826                 mlirBlockGetFirstOperation(self.get());
2827             return PyOperationIterator(self.getParentOperation(),
2828                                        firstOperation);
2829           },
2830           "Iterates over operations in the block.")
2831       .def("__eq__",
2832            [](PyBlock &self, PyBlock &other) {
2833              return self.get().ptr == other.get().ptr;
2834            })
2835       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2836       .def(
2837           "__str__",
2838           [](PyBlock &self) {
2839             self.checkValid();
2840             PyPrintAccumulator printAccum;
2841             mlirBlockPrint(self.get(), printAccum.getCallback(),
2842                            printAccum.getUserData());
2843             return printAccum.join();
2844           },
2845           "Returns the assembly form of the block.")
2846       .def(
2847           "append",
2848           [](PyBlock &self, PyOperationBase &operation) {
2849             if (operation.getOperation().isAttached())
2850               operation.getOperation().detachFromParent();
2851 
2852             MlirOperation mlirOperation = operation.getOperation().get();
2853             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2854             operation.getOperation().setAttached(
2855                 self.getParentOperation().getObject());
2856           },
2857           py::arg("operation"),
2858           "Appends an operation to this block. If the operation is currently "
2859           "in another block, it will be moved.");
2860 
2861   //----------------------------------------------------------------------------
2862   // Mapping of PyInsertionPoint.
2863   //----------------------------------------------------------------------------
2864 
2865   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2866       .def(py::init<PyBlock &>(), py::arg("block"),
2867            "Inserts after the last operation but still inside the block.")
2868       .def("__enter__", &PyInsertionPoint::contextEnter)
2869       .def("__exit__", &PyInsertionPoint::contextExit)
2870       .def_property_readonly_static(
2871           "current",
2872           [](py::object & /*class*/) {
2873             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2874             if (!ip)
2875               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2876             return ip;
2877           },
2878           "Gets the InsertionPoint bound to the current thread or raises "
2879           "ValueError if none has been set")
2880       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2881            "Inserts before a referenced operation.")
2882       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2883                   py::arg("block"), "Inserts at the beginning of the block.")
2884       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2885                   py::arg("block"), "Inserts before the block terminator.")
2886       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2887            "Inserts an operation.")
2888       .def_property_readonly(
2889           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2890           "Returns the block that this InsertionPoint points to.");
2891 
2892   //----------------------------------------------------------------------------
2893   // Mapping of PyAttribute.
2894   //----------------------------------------------------------------------------
2895   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2896       // Delegate to the PyAttribute copy constructor, which will also lifetime
2897       // extend the backing context which owns the MlirAttribute.
2898       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2899            "Casts the passed attribute to the generic Attribute")
2900       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2901                              &PyAttribute::getCapsule)
2902       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2903       .def_static(
2904           "parse",
2905           [](std::string attrSpec, DefaultingPyMlirContext context) {
2906             MlirAttribute type = mlirAttributeParseGet(
2907                 context->get(), toMlirStringRef(attrSpec));
2908             // TODO: Rework error reporting once diagnostic engine is exposed
2909             // in C API.
2910             if (mlirAttributeIsNull(type)) {
2911               throw SetPyError(PyExc_ValueError,
2912                                Twine("Unable to parse attribute: '") +
2913                                    attrSpec + "'");
2914             }
2915             return PyAttribute(context->getRef(), type);
2916           },
2917           py::arg("asm"), py::arg("context") = py::none(),
2918           "Parses an attribute from an assembly form")
2919       .def_property_readonly(
2920           "context",
2921           [](PyAttribute &self) { return self.getContext().getObject(); },
2922           "Context that owns the Attribute")
2923       .def_property_readonly("type",
2924                              [](PyAttribute &self) {
2925                                return PyType(self.getContext()->getRef(),
2926                                              mlirAttributeGetType(self));
2927                              })
2928       .def(
2929           "get_named",
2930           [](PyAttribute &self, std::string name) {
2931             return PyNamedAttribute(self, std::move(name));
2932           },
2933           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2934       .def("__eq__",
2935            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2936       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2937       .def("__hash__",
2938            [](PyAttribute &self) {
2939              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2940            })
2941       .def(
2942           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2943           kDumpDocstring)
2944       .def(
2945           "__str__",
2946           [](PyAttribute &self) {
2947             PyPrintAccumulator printAccum;
2948             mlirAttributePrint(self, printAccum.getCallback(),
2949                                printAccum.getUserData());
2950             return printAccum.join();
2951           },
2952           "Returns the assembly form of the Attribute.")
2953       .def("__repr__", [](PyAttribute &self) {
2954         // Generally, assembly formats are not printed for __repr__ because
2955         // this can cause exceptionally long debug output and exceptions.
2956         // However, attribute values are generally considered useful and are
2957         // printed. This may need to be re-evaluated if debug dumps end up
2958         // being excessive.
2959         PyPrintAccumulator printAccum;
2960         printAccum.parts.append("Attribute(");
2961         mlirAttributePrint(self, printAccum.getCallback(),
2962                            printAccum.getUserData());
2963         printAccum.parts.append(")");
2964         return printAccum.join();
2965       });
2966 
2967   //----------------------------------------------------------------------------
2968   // Mapping of PyNamedAttribute
2969   //----------------------------------------------------------------------------
2970   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2971       .def("__repr__",
2972            [](PyNamedAttribute &self) {
2973              PyPrintAccumulator printAccum;
2974              printAccum.parts.append("NamedAttribute(");
2975              printAccum.parts.append(
2976                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
2977                          mlirIdentifierStr(self.namedAttr.name).length));
2978              printAccum.parts.append("=");
2979              mlirAttributePrint(self.namedAttr.attribute,
2980                                 printAccum.getCallback(),
2981                                 printAccum.getUserData());
2982              printAccum.parts.append(")");
2983              return printAccum.join();
2984            })
2985       .def_property_readonly(
2986           "name",
2987           [](PyNamedAttribute &self) {
2988             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2989                            mlirIdentifierStr(self.namedAttr.name).length);
2990           },
2991           "The name of the NamedAttribute binding")
2992       .def_property_readonly(
2993           "attr",
2994           [](PyNamedAttribute &self) {
2995             // TODO: When named attribute is removed/refactored, also remove
2996             // this constructor (it does an inefficient table lookup).
2997             auto contextRef = PyMlirContext::forContext(
2998                 mlirAttributeGetContext(self.namedAttr.attribute));
2999             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
3000           },
3001           py::keep_alive<0, 1>(),
3002           "The underlying generic attribute of the NamedAttribute binding");
3003 
3004   //----------------------------------------------------------------------------
3005   // Mapping of PyType.
3006   //----------------------------------------------------------------------------
3007   py::class_<PyType>(m, "Type", py::module_local())
3008       // Delegate to the PyType copy constructor, which will also lifetime
3009       // extend the backing context which owns the MlirType.
3010       .def(py::init<PyType &>(), py::arg("cast_from_type"),
3011            "Casts the passed type to the generic Type")
3012       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3013       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3014       .def_static(
3015           "parse",
3016           [](std::string typeSpec, DefaultingPyMlirContext context) {
3017             MlirType type =
3018                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3019             // TODO: Rework error reporting once diagnostic engine is exposed
3020             // in C API.
3021             if (mlirTypeIsNull(type)) {
3022               throw SetPyError(PyExc_ValueError,
3023                                Twine("Unable to parse type: '") + typeSpec +
3024                                    "'");
3025             }
3026             return PyType(context->getRef(), type);
3027           },
3028           py::arg("asm"), py::arg("context") = py::none(),
3029           kContextParseTypeDocstring)
3030       .def_property_readonly(
3031           "context", [](PyType &self) { return self.getContext().getObject(); },
3032           "Context that owns the Type")
3033       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3034       .def("__eq__", [](PyType &self, py::object &other) { return false; })
3035       .def("__hash__",
3036            [](PyType &self) {
3037              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3038            })
3039       .def(
3040           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3041       .def(
3042           "__str__",
3043           [](PyType &self) {
3044             PyPrintAccumulator printAccum;
3045             mlirTypePrint(self, printAccum.getCallback(),
3046                           printAccum.getUserData());
3047             return printAccum.join();
3048           },
3049           "Returns the assembly form of the type.")
3050       .def("__repr__", [](PyType &self) {
3051         // Generally, assembly formats are not printed for __repr__ because
3052         // this can cause exceptionally long debug output and exceptions.
3053         // However, types are an exception as they typically have compact
3054         // assembly forms and printing them is useful.
3055         PyPrintAccumulator printAccum;
3056         printAccum.parts.append("Type(");
3057         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
3058         printAccum.parts.append(")");
3059         return printAccum.join();
3060       });
3061 
3062   //----------------------------------------------------------------------------
3063   // Mapping of Value.
3064   //----------------------------------------------------------------------------
3065   py::class_<PyValue>(m, "Value", py::module_local())
3066       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3067       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3068       .def_property_readonly(
3069           "context",
3070           [](PyValue &self) { return self.getParentOperation()->getContext(); },
3071           "Context in which the value lives.")
3072       .def(
3073           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3074           kDumpDocstring)
3075       .def_property_readonly(
3076           "owner",
3077           [](PyValue &self) {
3078             assert(mlirOperationEqual(self.getParentOperation()->get(),
3079                                       mlirOpResultGetOwner(self.get())) &&
3080                    "expected the owner of the value in Python to match that in "
3081                    "the IR");
3082             return self.getParentOperation().getObject();
3083           })
3084       .def("__eq__",
3085            [](PyValue &self, PyValue &other) {
3086              return self.get().ptr == other.get().ptr;
3087            })
3088       .def("__eq__", [](PyValue &self, py::object other) { return false; })
3089       .def("__hash__",
3090            [](PyValue &self) {
3091              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3092            })
3093       .def(
3094           "__str__",
3095           [](PyValue &self) {
3096             PyPrintAccumulator printAccum;
3097             printAccum.parts.append("Value(");
3098             mlirValuePrint(self.get(), printAccum.getCallback(),
3099                            printAccum.getUserData());
3100             printAccum.parts.append(")");
3101             return printAccum.join();
3102           },
3103           kValueDunderStrDocstring)
3104       .def_property_readonly("type", [](PyValue &self) {
3105         return PyType(self.getParentOperation()->getContext(),
3106                       mlirValueGetType(self.get()));
3107       });
3108   PyBlockArgument::bind(m);
3109   PyOpResult::bind(m);
3110 
3111   //----------------------------------------------------------------------------
3112   // Mapping of SymbolTable.
3113   //----------------------------------------------------------------------------
3114   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3115       .def(py::init<PyOperationBase &>())
3116       .def("__getitem__", &PySymbolTable::dunderGetItem)
3117       .def("insert", &PySymbolTable::insert, py::arg("operation"))
3118       .def("erase", &PySymbolTable::erase, py::arg("operation"))
3119       .def("__delitem__", &PySymbolTable::dunderDel)
3120       .def("__contains__",
3121            [](PySymbolTable &table, const std::string &name) {
3122              return !mlirOperationIsNull(mlirSymbolTableLookup(
3123                  table, mlirStringRefCreate(name.data(), name.length())));
3124            })
3125       // Static helpers.
3126       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3127                   py::arg("symbol"), py::arg("name"))
3128       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3129                   py::arg("symbol"))
3130       .def_static("get_visibility", &PySymbolTable::getVisibility,
3131                   py::arg("symbol"))
3132       .def_static("set_visibility", &PySymbolTable::setVisibility,
3133                   py::arg("symbol"), py::arg("visibility"))
3134       .def_static("replace_all_symbol_uses",
3135                   &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3136                   py::arg("new_symbol"), py::arg("from_op"))
3137       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3138                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
3139                   py::arg("callback"));
3140 
  // Container bindings.
  // Each call registers one container/iterator Python class through its
  // static bind(m) helper on the module `m`.
  // NOTE(review): keep this order. pybind11 renders a function's signature
  // docstring using types already registered at bind time, so reordering
  // these calls may change generated docstrings; verify before changing.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  // Registers the PyGlobalDebugFlag class on the module.
  PyGlobalDebugFlag::bind(m);
3155 }
3156