1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 
23 #include <utility>
24 
25 namespace py = pybind11;
26 using namespace mlir;
27 using namespace mlir::python;
28 
29 using llvm::SmallVector;
30 using llvm::StringRef;
31 using llvm::Twine;
32 
33 //------------------------------------------------------------------------------
34 // Docstrings (trivial, non-duplicated docstrings are included inline).
35 //------------------------------------------------------------------------------
36 
// Python help text for Context-level entry points (type parsing and Location
// factories), module parsing and operation creation. Each array is a raw
// string literal whose contents are shown verbatim to Python users, so any
// edit inside R"(...)" is a user-visible change. The binding sites consuming
// these constants are outside this section.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

// NOTE(review): spelled `...DocString` (capital S), unlike its siblings;
// renaming it requires updating every use.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

// Help text for parsing a module from its textual assembly form.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Help text for operation creation. The names under "Args:" are presumably
// the keyword arguments of the binding (not visible here) — keep in sync.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
82 
// Python help text for printing an operation. The names listed under "Args:"
// are the keyword arguments users type, so they must match the binding's
// argument names exactly. `use_local_scope` was previously misspelled as
// `use_local_Scope`.
// NOTE(review): confirm against the py::arg names of the print binding, which
// is not visible in this section.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";
108 
// Python help text for retrieving an operation's assembly as a value. The text
// refers readers to print() for the shared formatting keyword arguments.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Help text presumably attached to Operation.__str__ (binding not visible).
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Generic help text for dump() methods; presumably shared by several bound
// classes (uses not visible here).
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Help text for BlockList.append (see PyBlockList::bind).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Help text presumably attached to Value.__str__ (binding not visible).
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
147 
148 //------------------------------------------------------------------------------
149 // Utilities.
150 //------------------------------------------------------------------------------
151 
152 /// Helper for creating an @classmethod.
153 template <class Func, typename... Args>
154 py::object classmethod(Func f, Args... args) {
155   py::object cf = py::cpp_function(f, args...);
156   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
157 }
158 
159 static py::object
160 createCustomDialectWrapper(const std::string &dialectNamespace,
161                            py::object dialectDescriptor) {
162   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
163   if (!dialectClass) {
164     // Use the base class.
165     return py::cast(PyDialect(std::move(dialectDescriptor)));
166   }
167 
168   // Create the custom implementation.
169   return (*dialectClass)(std::move(dialectDescriptor));
170 }
171 
172 static MlirStringRef toMlirStringRef(const std::string &s) {
173   return mlirStringRefCreate(s.data(), s.size());
174 }
175 
176 /// Wrapper for the global LLVM debugging flag.
177 struct PyGlobalDebugFlag {
178   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
179 
180   static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
181 
182   static void bind(py::module &m) {
183     // Debug flags.
184     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
185         .def_property_static("flag", &PyGlobalDebugFlag::get,
186                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
187   }
188 };
189 
190 //------------------------------------------------------------------------------
191 // Collections.
192 //------------------------------------------------------------------------------
193 
194 namespace {
195 
196 class PyRegionIterator {
197 public:
198   PyRegionIterator(PyOperationRef operation)
199       : operation(std::move(operation)) {}
200 
201   PyRegionIterator &dunderIter() { return *this; }
202 
203   PyRegion dunderNext() {
204     operation->checkValid();
205     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
206       throw py::stop_iteration();
207     }
208     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
209     return PyRegion(operation, region);
210   }
211 
212   static void bind(py::module &m) {
213     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
214         .def("__iter__", &PyRegionIterator::dunderIter)
215         .def("__next__", &PyRegionIterator::dunderNext);
216   }
217 
218 private:
219   PyOperationRef operation;
220   int nextIndex = 0;
221 };
222 
223 /// Regions of an op are fixed length and indexed numerically so are represented
224 /// with a sequence-like container.
225 class PyRegionList {
226 public:
227   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
228 
229   intptr_t dunderLen() {
230     operation->checkValid();
231     return mlirOperationGetNumRegions(operation->get());
232   }
233 
234   PyRegion dunderGetItem(intptr_t index) {
235     // dunderLen checks validity.
236     if (index < 0 || index >= dunderLen()) {
237       throw SetPyError(PyExc_IndexError,
238                        "attempt to access out of bounds region");
239     }
240     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
241     return PyRegion(operation, region);
242   }
243 
244   static void bind(py::module &m) {
245     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
246         .def("__len__", &PyRegionList::dunderLen)
247         .def("__getitem__", &PyRegionList::dunderGetItem);
248   }
249 
250 private:
251   PyOperationRef operation;
252 };
253 
254 class PyBlockIterator {
255 public:
256   PyBlockIterator(PyOperationRef operation, MlirBlock next)
257       : operation(std::move(operation)), next(next) {}
258 
259   PyBlockIterator &dunderIter() { return *this; }
260 
261   PyBlock dunderNext() {
262     operation->checkValid();
263     if (mlirBlockIsNull(next)) {
264       throw py::stop_iteration();
265     }
266 
267     PyBlock returnBlock(operation, next);
268     next = mlirBlockGetNextInRegion(next);
269     return returnBlock;
270   }
271 
272   static void bind(py::module &m) {
273     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
274         .def("__iter__", &PyBlockIterator::dunderIter)
275         .def("__next__", &PyBlockIterator::dunderNext);
276   }
277 
278 private:
279   PyOperationRef operation;
280   MlirBlock next;
281 };
282 
283 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
284 /// we present them as a more full-featured list-like container but optimize
285 /// it for forward iteration. Blocks are always owned by a region.
286 class PyBlockList {
287 public:
288   PyBlockList(PyOperationRef operation, MlirRegion region)
289       : operation(std::move(operation)), region(region) {}
290 
291   PyBlockIterator dunderIter() {
292     operation->checkValid();
293     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
294   }
295 
296   intptr_t dunderLen() {
297     operation->checkValid();
298     intptr_t count = 0;
299     MlirBlock block = mlirRegionGetFirstBlock(region);
300     while (!mlirBlockIsNull(block)) {
301       count += 1;
302       block = mlirBlockGetNextInRegion(block);
303     }
304     return count;
305   }
306 
307   PyBlock dunderGetItem(intptr_t index) {
308     operation->checkValid();
309     if (index < 0) {
310       throw SetPyError(PyExc_IndexError,
311                        "attempt to access out of bounds block");
312     }
313     MlirBlock block = mlirRegionGetFirstBlock(region);
314     while (!mlirBlockIsNull(block)) {
315       if (index == 0) {
316         return PyBlock(operation, block);
317       }
318       block = mlirBlockGetNextInRegion(block);
319       index -= 1;
320     }
321     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
322   }
323 
324   PyBlock appendBlock(const py::args &pyArgTypes) {
325     operation->checkValid();
326     llvm::SmallVector<MlirType, 4> argTypes;
327     llvm::SmallVector<MlirLocation, 4> argLocs;
328     argTypes.reserve(pyArgTypes.size());
329     argLocs.reserve(pyArgTypes.size());
330     for (auto &pyArg : pyArgTypes) {
331       argTypes.push_back(pyArg.cast<PyType &>());
332       // TODO: Pass in a proper location here.
333       argLocs.push_back(
334           mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
335     }
336 
337     MlirBlock block =
338         mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
339     mlirRegionAppendOwnedBlock(region, block);
340     return PyBlock(operation, block);
341   }
342 
343   static void bind(py::module &m) {
344     py::class_<PyBlockList>(m, "BlockList", py::module_local())
345         .def("__getitem__", &PyBlockList::dunderGetItem)
346         .def("__iter__", &PyBlockList::dunderIter)
347         .def("__len__", &PyBlockList::dunderLen)
348         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
349   }
350 
351 private:
352   PyOperationRef operation;
353   MlirRegion region;
354 };
355 
356 class PyOperationIterator {
357 public:
358   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
359       : parentOperation(std::move(parentOperation)), next(next) {}
360 
361   PyOperationIterator &dunderIter() { return *this; }
362 
363   py::object dunderNext() {
364     parentOperation->checkValid();
365     if (mlirOperationIsNull(next)) {
366       throw py::stop_iteration();
367     }
368 
369     PyOperationRef returnOperation =
370         PyOperation::forOperation(parentOperation->getContext(), next);
371     next = mlirOperationGetNextInBlock(next);
372     return returnOperation->createOpView();
373   }
374 
375   static void bind(py::module &m) {
376     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
377         .def("__iter__", &PyOperationIterator::dunderIter)
378         .def("__next__", &PyOperationIterator::dunderNext);
379   }
380 
381 private:
382   PyOperationRef parentOperation;
383   MlirOperation next;
384 };
385 
386 /// Operations are exposed by the C-API as a forward-only linked list. In
387 /// Python, we present them as a more full-featured list-like container but
388 /// optimize it for forward iteration. Iterable operations are always owned
389 /// by a block.
390 class PyOperationList {
391 public:
392   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
393       : parentOperation(std::move(parentOperation)), block(block) {}
394 
395   PyOperationIterator dunderIter() {
396     parentOperation->checkValid();
397     return PyOperationIterator(parentOperation,
398                                mlirBlockGetFirstOperation(block));
399   }
400 
401   intptr_t dunderLen() {
402     parentOperation->checkValid();
403     intptr_t count = 0;
404     MlirOperation childOp = mlirBlockGetFirstOperation(block);
405     while (!mlirOperationIsNull(childOp)) {
406       count += 1;
407       childOp = mlirOperationGetNextInBlock(childOp);
408     }
409     return count;
410   }
411 
412   py::object dunderGetItem(intptr_t index) {
413     parentOperation->checkValid();
414     if (index < 0) {
415       throw SetPyError(PyExc_IndexError,
416                        "attempt to access out of bounds operation");
417     }
418     MlirOperation childOp = mlirBlockGetFirstOperation(block);
419     while (!mlirOperationIsNull(childOp)) {
420       if (index == 0) {
421         return PyOperation::forOperation(parentOperation->getContext(), childOp)
422             ->createOpView();
423       }
424       childOp = mlirOperationGetNextInBlock(childOp);
425       index -= 1;
426     }
427     throw SetPyError(PyExc_IndexError,
428                      "attempt to access out of bounds operation");
429   }
430 
431   static void bind(py::module &m) {
432     py::class_<PyOperationList>(m, "OperationList", py::module_local())
433         .def("__getitem__", &PyOperationList::dunderGetItem)
434         .def("__iter__", &PyOperationList::dunderIter)
435         .def("__len__", &PyOperationList::dunderLen);
436   }
437 
438 private:
439   PyOperationRef parentOperation;
440   MlirBlock block;
441 };
442 
443 } // namespace
444 
445 //------------------------------------------------------------------------------
446 // PyMlirContext
447 //------------------------------------------------------------------------------
448 
449 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
450   py::gil_scoped_acquire acquire;
451   auto &liveContexts = getLiveContexts();
452   liveContexts[context.ptr] = this;
453 }
454 
455 PyMlirContext::~PyMlirContext() {
456   // Note that the only public way to construct an instance is via the
457   // forContext method, which always puts the associated handle into
458   // liveContexts.
459   py::gil_scoped_acquire acquire;
460   getLiveContexts().erase(context.ptr);
461   mlirContextDestroy(context);
462 }
463 
464 py::object PyMlirContext::getCapsule() {
465   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
466 }
467 
468 py::object PyMlirContext::createFromCapsule(py::object capsule) {
469   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
470   if (mlirContextIsNull(rawContext))
471     throw py::error_already_set();
472   return forContext(rawContext).releaseObject();
473 }
474 
475 PyMlirContext *PyMlirContext::createNewContextForInit() {
476   MlirContext context = mlirContextCreate();
477   mlirRegisterAllDialects(context);
478   return new PyMlirContext(context);
479 }
480 
/// Returns a reference to the wrapper registered for `context` in the
/// live-context map. If none is registered yet, a new wrapper is created and
/// registered. The returned ref pairs the C++ pointer with a Python object for
/// it.
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
  // Hold the GIL while reading and updating the shared live-context map and
  // while creating Python objects.
  py::gil_scoped_acquire acquire;
  auto &liveContexts = getLiveContexts();
  auto it = liveContexts.find(context.ptr);
  if (it == liveContexts.end()) {
    // Create.
    // NOTE(review): the PyMlirContext constructor already records itself in
    // liveContexts, so the explicit assignment below is redundant.
    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
    // NOTE(review): py::cast on a raw pointer with the default policy
    // (automatic_reference -> reference) does not give Python ownership, so
    // this wrapper is never deleted by Python. Confirm this is intended (the
    // name "unowned" suggests it is).
    py::object pyRef = py::cast(unownedContextWrapper);
    assert(pyRef && "cast to py::object failed");
    liveContexts[context.ptr] = unownedContextWrapper;
    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
  }
  // Use existing.
  py::object pyRef = py::cast(it->second);
  return PyMlirContextRef(it->second, std::move(pyRef));
}
497 
498 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
499   static LiveContextMap liveContexts;
500   return liveContexts;
501 }
502 
503 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
504 
505 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
506 
507 size_t PyMlirContext::clearLiveOperations() {
508   for (auto &op : liveOperations)
509     op.second.second->setInvalid();
510   size_t numInvalidated = liveOperations.size();
511   liveOperations.clear();
512   return numInvalidated;
513 }
514 
515 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
516 
517 pybind11::object PyMlirContext::contextEnter() {
518   return PyThreadContextEntry::pushContext(*this);
519 }
520 
521 void PyMlirContext::contextExit(const pybind11::object &excType,
522                                 const pybind11::object &excVal,
523                                 const pybind11::object &excTb) {
524   PyThreadContextEntry::popContext(*this);
525 }
526 
527 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
528   // Note that ownership is transferred to the delete callback below by way of
529   // an explicit inc_ref (borrow).
530   PyDiagnosticHandler *pyHandler =
531       new PyDiagnosticHandler(get(), std::move(callback));
532   py::object pyHandlerObject =
533       py::cast(pyHandler, py::return_value_policy::take_ownership);
534   pyHandlerObject.inc_ref();
535 
536   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
537   // guaranteed to be known to pybind.
538   auto handlerCallback =
539       +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
540     PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
541     py::object pyDiagnosticObject =
542         py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
543 
544     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
545     bool result = false;
546     {
547       // Since this can be called from arbitrary C++ contexts, always get the
548       // gil.
549       py::gil_scoped_acquire gil;
550       try {
551         result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
552       } catch (std::exception &e) {
553         fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
554                 e.what());
555         pyHandler->hadError = true;
556       }
557     }
558 
559     pyDiagnostic->invalidate();
560     return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
561   };
562   auto deleteCallback = +[](void *userData) {
563     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
564     assert(pyHandler->registeredID && "handler is not registered");
565     pyHandler->registeredID.reset();
566 
567     // Decrement reference, balancing the inc_ref() above.
568     py::object pyHandlerObject =
569         py::cast(pyHandler, py::return_value_policy::reference);
570     pyHandlerObject.dec_ref();
571   };
572 
573   pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
574       get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
575   return pyHandlerObject;
576 }
577 
578 PyMlirContext &DefaultingPyMlirContext::resolve() {
579   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
580   if (!context) {
581     throw SetPyError(
582         PyExc_RuntimeError,
583         "An MLIR function requires a Context but none was provided in the call "
584         "or from the surrounding environment. Either pass to the function with "
585         "a 'context=' argument or establish a default using 'with Context():'");
586   }
587   return *context;
588 }
589 
590 //------------------------------------------------------------------------------
591 // PyThreadContextEntry management
592 //------------------------------------------------------------------------------
593 
594 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
595   static thread_local std::vector<PyThreadContextEntry> stack;
596   return stack;
597 }
598 
599 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
600   auto &stack = getStack();
601   if (stack.empty())
602     return nullptr;
603   return &stack.back();
604 }
605 
606 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
607                                 py::object insertionPoint,
608                                 py::object location) {
609   auto &stack = getStack();
610   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
611                      std::move(location));
612   // If the new stack has more than one entry and the context of the new top
613   // entry matches the previous, copy the insertionPoint and location from the
614   // previous entry if missing from the new top entry.
615   if (stack.size() > 1) {
616     auto &prev = *(stack.rbegin() + 1);
617     auto &current = stack.back();
618     if (current.context.is(prev.context)) {
619       // Default non-context objects from the previous entry.
620       if (!current.insertionPoint)
621         current.insertionPoint = prev.insertionPoint;
622       if (!current.location)
623         current.location = prev.location;
624     }
625   }
626 }
627 
628 PyMlirContext *PyThreadContextEntry::getContext() {
629   if (!context)
630     return nullptr;
631   return py::cast<PyMlirContext *>(context);
632 }
633 
634 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
635   if (!insertionPoint)
636     return nullptr;
637   return py::cast<PyInsertionPoint *>(insertionPoint);
638 }
639 
640 PyLocation *PyThreadContextEntry::getLocation() {
641   if (!location)
642     return nullptr;
643   return py::cast<PyLocation *>(location);
644 }
645 
646 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
647   auto *tos = getTopOfStack();
648   return tos ? tos->getContext() : nullptr;
649 }
650 
651 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
652   auto *tos = getTopOfStack();
653   return tos ? tos->getInsertionPoint() : nullptr;
654 }
655 
656 PyLocation *PyThreadContextEntry::getDefaultLocation() {
657   auto *tos = getTopOfStack();
658   return tos ? tos->getLocation() : nullptr;
659 }
660 
661 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
662   py::object contextObj = py::cast(context);
663   push(FrameKind::Context, /*context=*/contextObj,
664        /*insertionPoint=*/py::object(),
665        /*location=*/py::object());
666   return contextObj;
667 }
668 
669 void PyThreadContextEntry::popContext(PyMlirContext &context) {
670   auto &stack = getStack();
671   if (stack.empty())
672     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
673   auto &tos = stack.back();
674   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
675     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
676   stack.pop_back();
677 }
678 
679 py::object
680 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
681   py::object contextObj =
682       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
683   py::object insertionPointObj = py::cast(insertionPoint);
684   push(FrameKind::InsertionPoint,
685        /*context=*/contextObj,
686        /*insertionPoint=*/insertionPointObj,
687        /*location=*/py::object());
688   return insertionPointObj;
689 }
690 
691 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
692   auto &stack = getStack();
693   if (stack.empty())
694     throw SetPyError(PyExc_RuntimeError,
695                      "Unbalanced InsertionPoint enter/exit");
696   auto &tos = stack.back();
697   if (tos.frameKind != FrameKind::InsertionPoint &&
698       tos.getInsertionPoint() != &insertionPoint)
699     throw SetPyError(PyExc_RuntimeError,
700                      "Unbalanced InsertionPoint enter/exit");
701   stack.pop_back();
702 }
703 
704 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
705   py::object contextObj = location.getContext().getObject();
706   py::object locationObj = py::cast(location);
707   push(FrameKind::Location, /*context=*/contextObj,
708        /*insertionPoint=*/py::object(),
709        /*location=*/locationObj);
710   return locationObj;
711 }
712 
713 void PyThreadContextEntry::popLocation(PyLocation &location) {
714   auto &stack = getStack();
715   if (stack.empty())
716     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
717   auto &tos = stack.back();
718   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
719     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
720   stack.pop_back();
721 }
722 
723 //------------------------------------------------------------------------------
724 // PyDiagnostic*
725 //------------------------------------------------------------------------------
726 
727 void PyDiagnostic::invalidate() {
728   valid = false;
729   if (materializedNotes) {
730     for (auto &noteObject : *materializedNotes) {
731       PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
732       note->invalidate();
733     }
734   }
735 }
736 
737 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
738                                          py::object callback)
739     : context(context), callback(std::move(callback)) {}
740 
741 PyDiagnosticHandler::~PyDiagnosticHandler() = default;
742 
743 void PyDiagnosticHandler::detach() {
744   if (!registeredID)
745     return;
746   MlirDiagnosticHandlerID localID = *registeredID;
747   mlirContextDetachDiagnosticHandler(context, localID);
748   assert(!registeredID && "should have unregistered");
749   // Not strictly necessary but keeps stale pointers from being around to cause
750   // issues.
751   context = {nullptr};
752 }
753 
754 void PyDiagnostic::checkValid() {
755   if (!valid) {
756     throw std::invalid_argument(
757         "Diagnostic is invalid (used outside of callback)");
758   }
759 }
760 
761 MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
762   checkValid();
763   return mlirDiagnosticGetSeverity(diagnostic);
764 }
765 
766 PyLocation PyDiagnostic::getLocation() {
767   checkValid();
768   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
769   MlirContext context = mlirLocationGetContext(loc);
770   return PyLocation(PyMlirContext::forContext(context), loc);
771 }
772 
773 py::str PyDiagnostic::getMessage() {
774   checkValid();
775   py::object fileObject = py::module::import("io").attr("StringIO")();
776   PyFileAccumulator accum(fileObject, /*binary=*/false);
777   mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
778   return fileObject.attr("getvalue")();
779 }
780 
781 py::tuple PyDiagnostic::getNotes() {
782   checkValid();
783   if (materializedNotes)
784     return *materializedNotes;
785   intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
786   materializedNotes = py::tuple(numNotes);
787   for (intptr_t i = 0; i < numNotes; ++i) {
788     MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
789     py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag));
790     PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr());
791   }
792   return *materializedNotes;
793 }
794 
795 //------------------------------------------------------------------------------
796 // PyDialect, PyDialectDescriptor, PyDialects
797 //------------------------------------------------------------------------------
798 
799 MlirDialect PyDialects::getDialectForKey(const std::string &key,
800                                          bool attrError) {
801   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
802                                                     {key.data(), key.size()});
803   if (mlirDialectIsNull(dialect)) {
804     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
805                      Twine("Dialect '") + key + "' not found");
806   }
807   return dialect;
808 }
809 
810 //------------------------------------------------------------------------------
811 // PyLocation
812 //------------------------------------------------------------------------------
813 
814 py::object PyLocation::getCapsule() {
815   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
816 }
817 
818 PyLocation PyLocation::createFromCapsule(py::object capsule) {
819   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
820   if (mlirLocationIsNull(rawLoc))
821     throw py::error_already_set();
822   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
823                     rawLoc);
824 }
825 
826 py::object PyLocation::contextEnter() {
827   return PyThreadContextEntry::pushLocation(*this);
828 }
829 
830 void PyLocation::contextExit(const pybind11::object &excType,
831                              const pybind11::object &excVal,
832                              const pybind11::object &excTb) {
833   PyThreadContextEntry::popLocation(*this);
834 }
835 
836 PyLocation &DefaultingPyLocation::resolve() {
837   auto *location = PyThreadContextEntry::getDefaultLocation();
838   if (!location) {
839     throw SetPyError(
840         PyExc_RuntimeError,
841         "An MLIR function requires a Location but none was provided in the "
842         "call or from the surrounding environment. Either pass to the function "
843         "with a 'loc=' argument or establish a default using 'with loc:'");
844   }
845   return *location;
846 }
847 
848 //------------------------------------------------------------------------------
849 // PyModule
850 //------------------------------------------------------------------------------
851 
/// Wraps `module`, keeping its PyMlirContext alive through `contextRef`.
/// Registration in the context's live-module map is done by forModule(), not
/// here.
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
    : BaseContextObject(std::move(contextRef)), module(module) {}
854 
/// Unregisters this wrapper from its context's live-module map and destroys
/// the underlying MlirModule, which the wrapper owns unconditionally.
PyModule::~PyModule() {
  // Hold the GIL while mutating the live-module map; forModule() likewise
  // holds it while reading and inserting entries.
  py::gil_scoped_acquire acquire;
  auto &liveModules = getContext()->liveModules;
  assert(liveModules.count(module.ptr) == 1 &&
         "destroying module not in live map");
  // Drop the map entry first so no lookup can return this dying wrapper.
  liveModules.erase(module.ptr);
  mlirModuleDestroy(module);
}
863 
864 PyModuleRef PyModule::forModule(MlirModule module) {
865   MlirContext context = mlirModuleGetContext(module);
866   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
867 
868   py::gil_scoped_acquire acquire;
869   auto &liveModules = contextRef->liveModules;
870   auto it = liveModules.find(module.ptr);
871   if (it == liveModules.end()) {
872     // Create.
873     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
874     // Note that the default return value policy on cast is automatic_reference,
875     // which does not take ownership (delete will not be called).
876     // Just be explicit.
877     py::object pyRef =
878         py::cast(unownedModule, py::return_value_policy::take_ownership);
879     unownedModule->handle = pyRef;
880     liveModules[module.ptr] =
881         std::make_pair(unownedModule->handle, unownedModule);
882     return PyModuleRef(unownedModule, std::move(pyRef));
883   }
884   // Use existing.
885   PyModule *existing = it->second.second;
886   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
887   return PyModuleRef(existing, std::move(pyRef));
888 }
889 
890 py::object PyModule::createFromCapsule(py::object capsule) {
891   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
892   if (mlirModuleIsNull(rawModule))
893     throw py::error_already_set();
894   return forModule(rawModule).releaseObject();
895 }
896 
897 py::object PyModule::getCapsule() {
898   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
899 }
900 
901 //------------------------------------------------------------------------------
902 // PyOperation
903 //------------------------------------------------------------------------------
904 
/// Wraps `operation`, keeping its PyMlirContext alive through `contextRef`.
/// Registration in the context's live-operation map is done by
/// createInstance(), not here.
PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
    : BaseContextObject(std::move(contextRef)), operation(operation) {}
907 
/// Unregisters this wrapper from the context's live-operation map and, if the
/// operation is detached, destroys the underlying MlirOperation.
PyOperation::~PyOperation() {
  // If the operation has already been invalidated there is nothing to do.
  // erase() has already removed the map entry and destroyed the operation.
  if (!valid)
    return;
  auto &liveOperations = getContext()->liveOperations;
  assert(liveOperations.count(operation.ptr) == 1 &&
         "destroying operation not in live map");
  // After this, forOperation() will create a fresh wrapper for the pointer.
  liveOperations.erase(operation.ptr);
  // Only detached operations are destroyed here. Attached ones are left
  // alone; presumably the enclosing IR manages their lifetime.
  if (!isAttached()) {
    mlirOperationDestroy(operation);
  }
}
920 
/// Creates a new Python-owned PyOperation for `operation` and records it in
/// the context's live-operation map, overwriting any existing entry for the
/// same pointer. forOperation() and createDetached() check for an existing
/// entry before calling this.
///
/// If `parentKeepAlive` is non-null it is stored on the wrapper, presumably to
/// keep the parent Python object alive as long as this wrapper.
PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
                                           MlirOperation operation,
                                           py::object parentKeepAlive) {
  // Must be bound before `contextRef` is moved into the new PyOperation.
  auto &liveOperations = contextRef->liveOperations;
  // Create.
  PyOperation *unownedOperation =
      new PyOperation(std::move(contextRef), operation);
  // Note that the default return value policy on cast is automatic_reference,
  // which does not take ownership (delete will not be called).
  // Just be explicit.
  py::object pyRef =
      py::cast(unownedOperation, py::return_value_policy::take_ownership);
  unownedOperation->handle = pyRef;
  if (parentKeepAlive) {
    unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
  }
  liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
  return PyOperationRef(unownedOperation, std::move(pyRef));
}
940 
941 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
942                                          MlirOperation operation,
943                                          py::object parentKeepAlive) {
944   auto &liveOperations = contextRef->liveOperations;
945   auto it = liveOperations.find(operation.ptr);
946   if (it == liveOperations.end()) {
947     // Create.
948     return createInstance(std::move(contextRef), operation,
949                           std::move(parentKeepAlive));
950   }
951   // Use existing.
952   PyOperation *existing = it->second.second;
953   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
954   return PyOperationRef(existing, std::move(pyRef));
955 }
956 
957 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
958                                            MlirOperation operation,
959                                            py::object parentKeepAlive) {
960   auto &liveOperations = contextRef->liveOperations;
961   assert(liveOperations.count(operation.ptr) == 0 &&
962          "cannot create detached operation that already exists");
963   (void)liveOperations;
964 
965   PyOperationRef created = createInstance(std::move(contextRef), operation,
966                                           std::move(parentKeepAlive));
967   created->attached = false;
968   return created;
969 }
970 
971 void PyOperation::checkValid() const {
972   if (!valid) {
973     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
974   }
975 }
976 
977 void PyOperationBase::print(py::object fileObject, bool binary,
978                             llvm::Optional<int64_t> largeElementsLimit,
979                             bool enableDebugInfo, bool prettyDebugInfo,
980                             bool printGenericOpForm, bool useLocalScope,
981                             bool assumeVerified) {
982   PyOperation &operation = getOperation();
983   operation.checkValid();
984   if (fileObject.is_none())
985     fileObject = py::module::import("sys").attr("stdout");
986 
987   if (!assumeVerified && !printGenericOpForm &&
988       !mlirOperationVerify(operation)) {
989     std::string message("// Verification failed, printing generic form\n");
990     if (binary) {
991       fileObject.attr("write")(py::bytes(message));
992     } else {
993       fileObject.attr("write")(py::str(message));
994     }
995     printGenericOpForm = true;
996   }
997 
998   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
999   if (largeElementsLimit)
1000     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1001   if (enableDebugInfo)
1002     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
1003   if (printGenericOpForm)
1004     mlirOpPrintingFlagsPrintGenericOpForm(flags);
1005   if (useLocalScope)
1006     mlirOpPrintingFlagsUseLocalScope(flags);
1007 
1008   PyFileAccumulator accum(fileObject, binary);
1009   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1010                               accum.getUserData());
1011   mlirOpPrintingFlagsDestroy(flags);
1012 }
1013 
1014 py::object PyOperationBase::getAsm(bool binary,
1015                                    llvm::Optional<int64_t> largeElementsLimit,
1016                                    bool enableDebugInfo, bool prettyDebugInfo,
1017                                    bool printGenericOpForm, bool useLocalScope,
1018                                    bool assumeVerified) {
1019   py::object fileObject;
1020   if (binary) {
1021     fileObject = py::module::import("io").attr("BytesIO")();
1022   } else {
1023     fileObject = py::module::import("io").attr("StringIO")();
1024   }
1025   print(fileObject, /*binary=*/binary,
1026         /*largeElementsLimit=*/largeElementsLimit,
1027         /*enableDebugInfo=*/enableDebugInfo,
1028         /*prettyDebugInfo=*/prettyDebugInfo,
1029         /*printGenericOpForm=*/printGenericOpForm,
1030         /*useLocalScope=*/useLocalScope,
1031         /*assumeVerified=*/assumeVerified);
1032 
1033   return fileObject.attr("getvalue")();
1034 }
1035 
1036 void PyOperationBase::moveAfter(PyOperationBase &other) {
1037   PyOperation &operation = getOperation();
1038   PyOperation &otherOp = other.getOperation();
1039   operation.checkValid();
1040   otherOp.checkValid();
1041   mlirOperationMoveAfter(operation, otherOp);
1042   operation.parentKeepAlive = otherOp.parentKeepAlive;
1043 }
1044 
1045 void PyOperationBase::moveBefore(PyOperationBase &other) {
1046   PyOperation &operation = getOperation();
1047   PyOperation &otherOp = other.getOperation();
1048   operation.checkValid();
1049   otherOp.checkValid();
1050   mlirOperationMoveBefore(operation, otherOp);
1051   operation.parentKeepAlive = otherOp.parentKeepAlive;
1052 }
1053 
1054 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
1055   checkValid();
1056   if (!isAttached())
1057     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
1058   MlirOperation operation = mlirOperationGetParentOperation(get());
1059   if (mlirOperationIsNull(operation))
1060     return {};
1061   return PyOperation::forOperation(getContext(), operation);
1062 }
1063 
1064 PyBlock PyOperation::getBlock() {
1065   checkValid();
1066   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
1067   MlirBlock block = mlirOperationGetBlock(get());
1068   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1069   assert(parentOperation && "Operation has no parent");
1070   return PyBlock{std::move(*parentOperation), block};
1071 }
1072 
1073 py::object PyOperation::getCapsule() {
1074   checkValid();
1075   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1076 }
1077 
1078 py::object PyOperation::createFromCapsule(py::object capsule) {
1079   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1080   if (mlirOperationIsNull(rawOperation))
1081     throw py::error_already_set();
1082   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1083   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1084       .releaseObject();
1085 }
1086 
1087 static void maybeInsertOperation(PyOperationRef &op,
1088                                  const py::object &maybeIp) {
1089   // InsertPoint active?
1090   if (!maybeIp.is(py::cast(false))) {
1091     PyInsertionPoint *ip;
1092     if (maybeIp.is_none()) {
1093       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1094     } else {
1095       ip = py::cast<PyInsertionPoint *>(maybeIp);
1096     }
1097     if (ip)
1098       ip->insert(*op.get());
1099   }
1100 }
1101 
/// Creates a new operation named `name` at `location`. Its inputs are the
/// optional result types, operands, attributes (a dict of str keys to
/// Attribute values) and successor blocks, plus `regions` newly created
/// regions. The operation is created detached and then inserted according to
/// `maybeIp` via maybeInsertOperation().
///
/// All Python-side inputs are validated and unpacked before any MLIR state is
/// built. Raises ValueError for a negative region count or for a None
/// operand, result type or successor. Raises py::cast_error for a non-string
/// attribute key or a non-Attribute (or None) attribute value.
/// Returns the new operation wrapped by createOpView().
py::object PyOperation::create(
    const std::string &name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, const py::object &maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Keys are held as owned std::strings so their bytes outlive the
  // MlirStringRefs created from them when the state is built below.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        // (It must be caught before its base class py::cast_error below.)
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  // Wrap it as detached (Python-owned) first; insertion, if any, follows.
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created->createOpView();
}
1223 
1224 py::object PyOperation::clone(const py::object &maybeIp) {
1225   MlirOperation clonedOperation = mlirOperationClone(operation);
1226   PyOperationRef cloned =
1227       PyOperation::createDetached(getContext(), clonedOperation);
1228   maybeInsertOperation(cloned, maybeIp);
1229 
1230   return cloned->createOpView();
1231 }
1232 
1233 py::object PyOperation::createOpView() {
1234   checkValid();
1235   MlirIdentifier ident = mlirOperationGetName(get());
1236   MlirStringRef identStr = mlirIdentifierStr(ident);
1237   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1238       StringRef(identStr.data, identStr.length));
1239   if (opViewClass)
1240     return (*opViewClass)(getRef().getObject());
1241   return py::cast(PyOpView(getRef().getObject()));
1242 }
1243 
1244 void PyOperation::erase() {
1245   checkValid();
1246   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1247   // Python reference to a child operation is live. All children should also
1248   // have their `valid` bit set to false.
1249   auto &liveOperations = getContext()->liveOperations;
1250   if (liveOperations.count(operation.ptr))
1251     liveOperations.erase(operation.ptr);
1252   mlirOperationDestroy(operation);
1253   valid = false;
1254 }
1255 
1256 //------------------------------------------------------------------------------
1257 // PyOpView
1258 //------------------------------------------------------------------------------
1259 
1260 py::object PyOpView::buildGeneric(
1261     const py::object &cls, py::list resultTypeList, py::list operandList,
1262     llvm::Optional<py::dict> attributes,
1263     llvm::Optional<std::vector<PyBlock *>> successors,
1264     llvm::Optional<int> regions, DefaultingPyLocation location,
1265     const py::object &maybeIp) {
1266   PyMlirContextRef context = location->getContext();
1267   // Class level operation construction metadata.
1268   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1269   // Operand and result segment specs are either none, which does no
1270   // variadic unpacking, or a list of ints with segment sizes, where each
1271   // element is either a positive number (typically 1 for a scalar) or -1 to
1272   // indicate that it is derived from the length of the same-indexed operand
1273   // or result (implying that it is a list at that position).
1274   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1275   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1276 
1277   std::vector<uint32_t> operandSegmentLengths;
1278   std::vector<uint32_t> resultSegmentLengths;
1279 
1280   // Validate/determine region count.
1281   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1282   int opMinRegionCount = std::get<0>(opRegionSpec);
1283   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1284   if (!regions) {
1285     regions = opMinRegionCount;
1286   }
1287   if (*regions < opMinRegionCount) {
1288     throw py::value_error(
1289         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1290          llvm::Twine(opMinRegionCount) +
1291          " regions but was built with regions=" + llvm::Twine(*regions))
1292             .str());
1293   }
1294   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1295     throw py::value_error(
1296         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1297          llvm::Twine(opMinRegionCount) +
1298          " regions but was built with regions=" + llvm::Twine(*regions))
1299             .str());
1300   }
1301 
1302   // Unpack results.
1303   std::vector<PyType *> resultTypes;
1304   resultTypes.reserve(resultTypeList.size());
1305   if (resultSegmentSpecObj.is_none()) {
1306     // Non-variadic result unpacking.
1307     for (const auto &it : llvm::enumerate(resultTypeList)) {
1308       try {
1309         resultTypes.push_back(py::cast<PyType *>(it.value()));
1310         if (!resultTypes.back())
1311           throw py::cast_error();
1312       } catch (py::cast_error &err) {
1313         throw py::value_error((llvm::Twine("Result ") +
1314                                llvm::Twine(it.index()) + " of operation \"" +
1315                                name + "\" must be a Type (" + err.what() + ")")
1316                                   .str());
1317       }
1318     }
1319   } else {
1320     // Sized result unpacking.
1321     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1322     if (resultSegmentSpec.size() != resultTypeList.size()) {
1323       throw py::value_error((llvm::Twine("Operation \"") + name +
1324                              "\" requires " +
1325                              llvm::Twine(resultSegmentSpec.size()) +
1326                              " result segments but was provided " +
1327                              llvm::Twine(resultTypeList.size()))
1328                                 .str());
1329     }
1330     resultSegmentLengths.reserve(resultTypeList.size());
1331     for (const auto &it :
1332          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1333       int segmentSpec = std::get<1>(it.value());
1334       if (segmentSpec == 1 || segmentSpec == 0) {
1335         // Unpack unary element.
1336         try {
1337           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1338           if (resultType) {
1339             resultTypes.push_back(resultType);
1340             resultSegmentLengths.push_back(1);
1341           } else if (segmentSpec == 0) {
1342             // Allowed to be optional.
1343             resultSegmentLengths.push_back(0);
1344           } else {
1345             throw py::cast_error("was None and result is not optional");
1346           }
1347         } catch (py::cast_error &err) {
1348           throw py::value_error((llvm::Twine("Result ") +
1349                                  llvm::Twine(it.index()) + " of operation \"" +
1350                                  name + "\" must be a Type (" + err.what() +
1351                                  ")")
1352                                     .str());
1353         }
1354       } else if (segmentSpec == -1) {
1355         // Unpack sequence by appending.
1356         try {
1357           if (std::get<0>(it.value()).is_none()) {
1358             // Treat it as an empty list.
1359             resultSegmentLengths.push_back(0);
1360           } else {
1361             // Unpack the list.
1362             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1363             for (py::object segmentItem : segment) {
1364               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1365               if (!resultTypes.back()) {
1366                 throw py::cast_error("contained a None item");
1367               }
1368             }
1369             resultSegmentLengths.push_back(segment.size());
1370           }
1371         } catch (std::exception &err) {
1372           // NOTE: Sloppy to be using a catch-all here, but there are at least
1373           // three different unrelated exceptions that can be thrown in the
1374           // above "casts". Just keep the scope above small and catch them all.
1375           throw py::value_error((llvm::Twine("Result ") +
1376                                  llvm::Twine(it.index()) + " of operation \"" +
1377                                  name + "\" must be a Sequence of Types (" +
1378                                  err.what() + ")")
1379                                     .str());
1380         }
1381       } else {
1382         throw py::value_error("Unexpected segment spec");
1383       }
1384     }
1385   }
1386 
1387   // Unpack operands.
1388   std::vector<PyValue *> operands;
1389   operands.reserve(operands.size());
1390   if (operandSegmentSpecObj.is_none()) {
1391     // Non-sized operand unpacking.
1392     for (const auto &it : llvm::enumerate(operandList)) {
1393       try {
1394         operands.push_back(py::cast<PyValue *>(it.value()));
1395         if (!operands.back())
1396           throw py::cast_error();
1397       } catch (py::cast_error &err) {
1398         throw py::value_error((llvm::Twine("Operand ") +
1399                                llvm::Twine(it.index()) + " of operation \"" +
1400                                name + "\" must be a Value (" + err.what() + ")")
1401                                   .str());
1402       }
1403     }
1404   } else {
1405     // Sized operand unpacking.
1406     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1407     if (operandSegmentSpec.size() != operandList.size()) {
1408       throw py::value_error((llvm::Twine("Operation \"") + name +
1409                              "\" requires " +
1410                              llvm::Twine(operandSegmentSpec.size()) +
1411                              "operand segments but was provided " +
1412                              llvm::Twine(operandList.size()))
1413                                 .str());
1414     }
1415     operandSegmentLengths.reserve(operandList.size());
1416     for (const auto &it :
1417          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1418       int segmentSpec = std::get<1>(it.value());
1419       if (segmentSpec == 1 || segmentSpec == 0) {
1420         // Unpack unary element.
1421         try {
1422           auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1423           if (operandValue) {
1424             operands.push_back(operandValue);
1425             operandSegmentLengths.push_back(1);
1426           } else if (segmentSpec == 0) {
1427             // Allowed to be optional.
1428             operandSegmentLengths.push_back(0);
1429           } else {
1430             throw py::cast_error("was None and operand is not optional");
1431           }
1432         } catch (py::cast_error &err) {
1433           throw py::value_error((llvm::Twine("Operand ") +
1434                                  llvm::Twine(it.index()) + " of operation \"" +
1435                                  name + "\" must be a Value (" + err.what() +
1436                                  ")")
1437                                     .str());
1438         }
1439       } else if (segmentSpec == -1) {
1440         // Unpack sequence by appending.
1441         try {
1442           if (std::get<0>(it.value()).is_none()) {
1443             // Treat it as an empty list.
1444             operandSegmentLengths.push_back(0);
1445           } else {
1446             // Unpack the list.
1447             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1448             for (py::object segmentItem : segment) {
1449               operands.push_back(py::cast<PyValue *>(segmentItem));
1450               if (!operands.back()) {
1451                 throw py::cast_error("contained a None item");
1452               }
1453             }
1454             operandSegmentLengths.push_back(segment.size());
1455           }
1456         } catch (std::exception &err) {
1457           // NOTE: Sloppy to be using a catch-all here, but there are at least
1458           // three different unrelated exceptions that can be thrown in the
1459           // above "casts". Just keep the scope above small and catch them all.
1460           throw py::value_error((llvm::Twine("Operand ") +
1461                                  llvm::Twine(it.index()) + " of operation \"" +
1462                                  name + "\" must be a Sequence of Values (" +
1463                                  err.what() + ")")
1464                                     .str());
1465         }
1466       } else {
1467         throw py::value_error("Unexpected segment spec");
1468       }
1469     }
1470   }
1471 
1472   // Merge operand/result segment lengths into attributes if needed.
1473   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1474     // Dup.
1475     if (attributes) {
1476       attributes = py::dict(*attributes);
1477     } else {
1478       attributes = py::dict();
1479     }
1480     if (attributes->contains("result_segment_sizes") ||
1481         attributes->contains("operand_segment_sizes")) {
1482       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1483                             "'operand_segment_sizes' attribute is unsupported. "
1484                             "Use Operation.create for such low-level access.");
1485     }
1486 
1487     // Add result_segment_sizes attribute.
1488     if (!resultSegmentLengths.empty()) {
1489       int64_t size = resultSegmentLengths.size();
1490       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1491           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1492           resultSegmentLengths.size(), resultSegmentLengths.data());
1493       (*attributes)["result_segment_sizes"] =
1494           PyAttribute(context, segmentLengthAttr);
1495     }
1496 
1497     // Add operand_segment_sizes attribute.
1498     if (!operandSegmentLengths.empty()) {
1499       int64_t size = operandSegmentLengths.size();
1500       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1501           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1502           operandSegmentLengths.size(), operandSegmentLengths.data());
1503       (*attributes)["operand_segment_sizes"] =
1504           PyAttribute(context, segmentLengthAttr);
1505     }
1506   }
1507 
1508   // Delegate to create.
1509   return PyOperation::create(name,
1510                              /*results=*/std::move(resultTypes),
1511                              /*operands=*/std::move(operands),
1512                              /*attributes=*/std::move(attributes),
1513                              /*successors=*/std::move(successors),
1514                              /*regions=*/*regions, location, maybeIp);
1515 }
1516 
/// Wraps an existing operation object; any PyOperationBase (a PyOperation or
/// another OpView) is accepted. The view stores the underlying PyOperation and
/// that operation's canonical Python object.
PyOpView::PyOpView(const py::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    // NOTE(review): the `operationObject` member is initialized from the
    // `operation` member, so this relies on `operation` being declared first
    // in the class; confirm in IRModule.h.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1522 
1523 py::object PyOpView::createRawSubclass(const py::object &userClass) {
1524   // This is... a little gross. The typical pattern is to have a pure python
1525   // class that extends OpView like:
1526   //   class AddFOp(_cext.ir.OpView):
1527   //     def __init__(self, loc, lhs, rhs):
1528   //       operation = loc.context.create_operation(
1529   //           "addf", lhs, rhs, results=[lhs.type])
1530   //       super().__init__(operation)
1531   //
1532   // I.e. The goal of the user facing type is to provide a nice constructor
1533   // that has complete freedom for the op under construction. This is at odds
1534   // with our other desire to sometimes create this object by just passing an
1535   // operation (to initialize the base class). We could do *arg and **kwargs
1536   // munging to try to make it work, but instead, we synthesize a new class
1537   // on the fly which extends this user class (AddFOp in this example) and
1538   // *give it* the base class's __init__ method, thus bypassing the
1539   // intermediate subclass's __init__ method entirely. While slightly,
1540   // underhanded, this is safe/legal because the type hierarchy has not changed
1541   // (we just added a new leaf) and we aren't mucking around with __new__.
1542   // Typically, this new class will be stored on the original as "_Raw" and will
1543   // be used for casts and other things that need a variant of the class that
1544   // is initialized purely from an operation.
1545   py::object parentMetaclass =
1546       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1547   py::dict attributes;
1548   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1549   // now.
1550   //   auto opViewType = py::type::of<PyOpView>();
1551   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1552   attributes["__init__"] = opViewType.attr("__init__");
1553   py::str origName = userClass.attr("__name__");
1554   py::str newName = py::str("_") + origName;
1555   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1556 }
1557 
1558 //------------------------------------------------------------------------------
1559 // PyInsertionPoint.
1560 //------------------------------------------------------------------------------
1561 
/// Creates an insertion point at the end of `block`. No reference operation
/// is recorded, so insert() appends to the block.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}

/// Creates an insertion point immediately before `beforeOperationBase`. The
/// block is obtained from that operation via getBlock().
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      // NOTE: this reads the `refOperation` member initialized above, so it
      // relies on `refOperation` being declared before `block` in the class.
      block((*refOperation)->getBlock()) {}
1567 
1568 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1569   PyOperation &operation = operationBase.getOperation();
1570   if (operation.isAttached())
1571     throw SetPyError(PyExc_ValueError,
1572                      "Attempt to insert operation that is already attached");
1573   block.getParentOperation()->checkValid();
1574   MlirOperation beforeOp = {nullptr};
1575   if (refOperation) {
1576     // Insert before operation.
1577     (*refOperation)->checkValid();
1578     beforeOp = (*refOperation)->get();
1579   } else {
1580     // Insert at end (before null) is only valid if the block does not
1581     // already end in a known terminator (violating this will cause assertion
1582     // failures later).
1583     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1584       throw py::index_error("Cannot insert operation at the end of a block "
1585                             "that already has a terminator. Did you mean to "
1586                             "use 'InsertionPoint.at_block_terminator(block)' "
1587                             "versus 'InsertionPoint(block)'?");
1588     }
1589   }
1590   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1591   operation.setAttached();
1592 }
1593 
1594 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1595   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1596   if (mlirOperationIsNull(firstOp)) {
1597     // Just insert at end.
1598     return PyInsertionPoint(block);
1599   }
1600 
1601   // Insert before first op.
1602   PyOperationRef firstOpRef = PyOperation::forOperation(
1603       block.getParentOperation()->getContext(), firstOp);
1604   return PyInsertionPoint{block, std::move(firstOpRef)};
1605 }
1606 
1607 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1608   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1609   if (mlirOperationIsNull(terminator))
1610     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1611   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1612       block.getParentOperation()->getContext(), terminator);
1613   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1614 }
1615 
1616 py::object PyInsertionPoint::contextEnter() {
1617   return PyThreadContextEntry::pushInsertionPoint(*this);
1618 }
1619 
1620 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1621                                    const pybind11::object &excVal,
1622                                    const pybind11::object &excTb) {
1623   PyThreadContextEntry::popInsertionPoint(*this);
1624 }
1625 
1626 //------------------------------------------------------------------------------
1627 // PyAttribute.
1628 //------------------------------------------------------------------------------
1629 
1630 bool PyAttribute::operator==(const PyAttribute &other) {
1631   return mlirAttributeEqual(attr, other.attr);
1632 }
1633 
1634 py::object PyAttribute::getCapsule() {
1635   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1636 }
1637 
1638 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1639   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1640   if (mlirAttributeIsNull(rawAttr))
1641     throw py::error_already_set();
1642   return PyAttribute(
1643       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1644 }
1645 
1646 //------------------------------------------------------------------------------
1647 // PyNamedAttribute.
1648 //------------------------------------------------------------------------------
1649 
1650 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1651     : ownedName(new std::string(std::move(ownedName))) {
1652   namedAttr = mlirNamedAttributeGet(
1653       mlirIdentifierGet(mlirAttributeGetContext(attr),
1654                         toMlirStringRef(*this->ownedName)),
1655       attr);
1656 }
1657 
1658 //------------------------------------------------------------------------------
1659 // PyType.
1660 //------------------------------------------------------------------------------
1661 
1662 bool PyType::operator==(const PyType &other) {
1663   return mlirTypeEqual(type, other.type);
1664 }
1665 
1666 py::object PyType::getCapsule() {
1667   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1668 }
1669 
1670 PyType PyType::createFromCapsule(py::object capsule) {
1671   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1672   if (mlirTypeIsNull(rawType))
1673     throw py::error_already_set();
1674   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1675                 rawType);
1676 }
1677 
1678 //------------------------------------------------------------------------------
// PyValue and subclasses.
1680 //------------------------------------------------------------------------------
1681 
1682 pybind11::object PyValue::getCapsule() {
1683   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1684 }
1685 
1686 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1687   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1688   if (mlirValueIsNull(value))
1689     throw py::error_already_set();
1690   MlirOperation owner;
1691   if (mlirValueIsAOpResult(value))
1692     owner = mlirOpResultGetOwner(value);
1693   if (mlirValueIsABlockArgument(value))
1694     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1695   if (mlirOperationIsNull(owner))
1696     throw py::error_already_set();
1697   MlirContext ctx = mlirOperationGetContext(owner);
1698   PyOperationRef ownerRef =
1699       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1700   return PyValue(ownerRef, value);
1701 }
1702 
1703 //------------------------------------------------------------------------------
1704 // PySymbolTable.
1705 //------------------------------------------------------------------------------
1706 
1707 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1708     : operation(operation.getOperation().getRef()) {
1709   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1710   if (mlirSymbolTableIsNull(symbolTable)) {
1711     throw py::cast_error("Operation is not a Symbol Table.");
1712   }
1713 }
1714 
1715 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1716   operation->checkValid();
1717   MlirOperation symbol = mlirSymbolTableLookup(
1718       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1719   if (mlirOperationIsNull(symbol))
1720     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1721 
1722   return PyOperation::forOperation(operation->getContext(), symbol,
1723                                    operation.getObject())
1724       ->createOpView();
1725 }
1726 
1727 void PySymbolTable::erase(PyOperationBase &symbol) {
1728   operation->checkValid();
1729   symbol.getOperation().checkValid();
1730   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1731   // The operation is also erased, so we must invalidate it. There may be Python
1732   // references to this operation so we don't want to delete it from the list of
1733   // live operations here.
1734   symbol.getOperation().valid = false;
1735 }
1736 
1737 void PySymbolTable::dunderDel(const std::string &name) {
1738   py::object operation = dunderGetItem(name);
1739   erase(py::cast<PyOperationBase &>(operation));
1740 }
1741 
1742 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1743   operation->checkValid();
1744   symbol.getOperation().checkValid();
1745   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1746       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1747   if (mlirAttributeIsNull(symbolAttr))
1748     throw py::value_error("Expected operation to have a symbol name.");
1749   return PyAttribute(
1750       symbol.getOperation().getContext(),
1751       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1752 }
1753 
1754 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1755   // Op must already be a symbol.
1756   PyOperation &operation = symbol.getOperation();
1757   operation.checkValid();
1758   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1759   MlirAttribute existingNameAttr =
1760       mlirOperationGetAttributeByName(operation.get(), attrName);
1761   if (mlirAttributeIsNull(existingNameAttr))
1762     throw py::value_error("Expected operation to have a symbol name.");
1763   return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1764 }
1765 
1766 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1767                                   const std::string &name) {
1768   // Op must already be a symbol.
1769   PyOperation &operation = symbol.getOperation();
1770   operation.checkValid();
1771   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1772   MlirAttribute existingNameAttr =
1773       mlirOperationGetAttributeByName(operation.get(), attrName);
1774   if (mlirAttributeIsNull(existingNameAttr))
1775     throw py::value_error("Expected operation to have a symbol name.");
1776   MlirAttribute newNameAttr =
1777       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1778   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1779 }
1780 
1781 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1782   PyOperation &operation = symbol.getOperation();
1783   operation.checkValid();
1784   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1785   MlirAttribute existingVisAttr =
1786       mlirOperationGetAttributeByName(operation.get(), attrName);
1787   if (mlirAttributeIsNull(existingVisAttr))
1788     throw py::value_error("Expected operation to have a symbol visibility.");
1789   return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1790 }
1791 
1792 void PySymbolTable::setVisibility(PyOperationBase &symbol,
1793                                   const std::string &visibility) {
1794   if (visibility != "public" && visibility != "private" &&
1795       visibility != "nested")
1796     throw py::value_error(
1797         "Expected visibility to be 'public', 'private' or 'nested'");
1798   PyOperation &operation = symbol.getOperation();
1799   operation.checkValid();
1800   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1801   MlirAttribute existingVisAttr =
1802       mlirOperationGetAttributeByName(operation.get(), attrName);
1803   if (mlirAttributeIsNull(existingVisAttr))
1804     throw py::value_error("Expected operation to have a symbol visibility.");
1805   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1806                                                toMlirStringRef(visibility));
1807   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1808 }
1809 
1810 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1811                                          const std::string &newSymbol,
1812                                          PyOperationBase &from) {
1813   PyOperation &fromOperation = from.getOperation();
1814   fromOperation.checkValid();
1815   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1816           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1817           from.getOperation())))
1818 
1819     throw py::value_error("Symbol rename failed");
1820 }
1821 
1822 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1823                                      bool allSymUsesVisible,
1824                                      py::object callback) {
1825   PyOperation &fromOperation = from.getOperation();
1826   fromOperation.checkValid();
1827   struct UserData {
1828     PyMlirContextRef context;
1829     py::object callback;
1830     bool gotException;
1831     std::string exceptionWhat;
1832     py::object exceptionType;
1833   };
1834   UserData userData{
1835       fromOperation.getContext(), std::move(callback), false, {}, {}};
1836   mlirSymbolTableWalkSymbolTables(
1837       fromOperation.get(), allSymUsesVisible,
1838       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1839         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1840         auto pyFoundOp =
1841             PyOperation::forOperation(calleeUserData->context, foundOp);
1842         if (calleeUserData->gotException)
1843           return;
1844         try {
1845           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1846         } catch (py::error_already_set &e) {
1847           calleeUserData->gotException = true;
1848           calleeUserData->exceptionWhat = e.what();
1849           calleeUserData->exceptionType = e.type();
1850         }
1851       },
1852       static_cast<void *>(&userData));
1853   if (userData.gotException) {
1854     std::string message("Exception raised in callback: ");
1855     message.append(userData.exceptionWhat);
1856     throw std::runtime_error(message);
1857   }
1858 }
1859 
1860 namespace {
1861 /// CRTP base class for Python MLIR values that subclass Value and should be
1862 /// castable from it. The value hierarchy is one level deep and is not supposed
1863 /// to accommodate other levels unless core MLIR changes.
1864 template <typename DerivedTy>
1865 class PyConcreteValue : public PyValue {
1866 public:
1867   // Derived classes must define statics for:
1868   //   IsAFunctionTy isaFunction
1869   //   const char *pyClassName
1870   // and redefine bindDerived.
1871   using ClassTy = py::class_<DerivedTy, PyValue>;
1872   using IsAFunctionTy = bool (*)(MlirValue);
1873 
1874   PyConcreteValue() = default;
1875   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1876       : PyValue(operationRef, value) {}
1877   PyConcreteValue(PyValue &orig)
1878       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1879 
1880   /// Attempts to cast the original value to the derived type and throws on
1881   /// type mismatches.
1882   static MlirValue castFrom(PyValue &orig) {
1883     if (!DerivedTy::isaFunction(orig.get())) {
1884       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1885       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1886                                              DerivedTy::pyClassName +
1887                                              " (from " + origRepr + ")");
1888     }
1889     return orig.get();
1890   }
1891 
1892   /// Binds the Python module objects to functions of this class.
1893   static void bind(py::module &m) {
1894     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1895     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1896     cls.def_static(
1897         "isinstance",
1898         [](PyValue &otherValue) -> bool {
1899           return DerivedTy::isaFunction(otherValue);
1900         },
1901         py::arg("other_value"));
1902     DerivedTy::bindDerived(cls);
1903   }
1904 
1905   /// Implemented by derived classes to add methods to the Python subclass.
1906   static void bindDerived(ClassTy &m) {}
1907 };
1908 
1909 /// Python wrapper for MlirBlockArgument.
1910 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1911 public:
1912   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1913   static constexpr const char *pyClassName = "BlockArgument";
1914   using PyConcreteValue::PyConcreteValue;
1915 
1916   static void bindDerived(ClassTy &c) {
1917     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1918       return PyBlock(self.getParentOperation(),
1919                      mlirBlockArgumentGetOwner(self.get()));
1920     });
1921     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1922       return mlirBlockArgumentGetArgNumber(self.get());
1923     });
1924     c.def(
1925         "set_type",
1926         [](PyBlockArgument &self, PyType type) {
1927           return mlirBlockArgumentSetType(self.get(), type);
1928         },
1929         py::arg("type"));
1930   }
1931 };
1932 
1933 /// Python wrapper for MlirOpResult.
1934 class PyOpResult : public PyConcreteValue<PyOpResult> {
1935 public:
1936   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1937   static constexpr const char *pyClassName = "OpResult";
1938   using PyConcreteValue::PyConcreteValue;
1939 
1940   static void bindDerived(ClassTy &c) {
1941     c.def_property_readonly("owner", [](PyOpResult &self) {
1942       assert(
1943           mlirOperationEqual(self.getParentOperation()->get(),
1944                              mlirOpResultGetOwner(self.get())) &&
1945           "expected the owner of the value in Python to match that in the IR");
1946       return self.getParentOperation().getObject();
1947     });
1948     c.def_property_readonly("result_number", [](PyOpResult &self) {
1949       return mlirOpResultGetResultNumber(self.get());
1950     });
1951   }
1952 };
1953 
1954 /// Returns the list of types of the values held by container.
1955 template <typename Container>
1956 static std::vector<PyType> getValueTypes(Container &container,
1957                                          PyMlirContextRef &context) {
1958   std::vector<PyType> result;
1959   result.reserve(container.getNumElements());
1960   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1961     result.push_back(
1962         PyType(context, mlirValueGetType(container.getElement(i).get())));
1963   }
1964   return result;
1965 }
1966 
1967 /// A list of block arguments. Internally, these are stored as consecutive
1968 /// elements, random access is cheap. The argument list is associated with the
1969 /// operation that contains the block (detached blocks are not allowed in
1970 /// Python bindings) and extends its lifetime.
1971 class PyBlockArgumentList
1972     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1973 public:
1974   static constexpr const char *pyClassName = "BlockArgumentList";
1975 
1976   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1977                       intptr_t startIndex = 0, intptr_t length = -1,
1978                       intptr_t step = 1)
1979       : Sliceable(startIndex,
1980                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1981                   step),
1982         operation(std::move(operation)), block(block) {}
1983 
1984   /// Returns the number of arguments in the list.
1985   intptr_t getNumElements() {
1986     operation->checkValid();
1987     return mlirBlockGetNumArguments(block);
1988   }
1989 
1990   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1991   PyBlockArgument getElement(intptr_t pos) {
1992     MlirValue argument = mlirBlockGetArgument(block, pos);
1993     return PyBlockArgument(operation, argument);
1994   }
1995 
1996   /// Returns a sublist of this list.
1997   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1998                             intptr_t step) {
1999     return PyBlockArgumentList(operation, block, startIndex, length, step);
2000   }
2001 
2002   static void bindDerived(ClassTy &c) {
2003     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
2004       return getValueTypes(self, self.operation->getContext());
2005     });
2006   }
2007 
2008 private:
2009   PyOperationRef operation;
2010   MlirBlock block;
2011 };
2012 
2013 /// A list of operation operands. Internally, these are stored as consecutive
2014 /// elements, random access is cheap. The result list is associated with the
2015 /// operation whose results these are, and extends the lifetime of this
2016 /// operation.
2017 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2018 public:
2019   static constexpr const char *pyClassName = "OpOperandList";
2020 
2021   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2022                   intptr_t length = -1, intptr_t step = 1)
2023       : Sliceable(startIndex,
2024                   length == -1 ? mlirOperationGetNumOperands(operation->get())
2025                                : length,
2026                   step),
2027         operation(operation) {}
2028 
2029   intptr_t getNumElements() {
2030     operation->checkValid();
2031     return mlirOperationGetNumOperands(operation->get());
2032   }
2033 
2034   PyValue getElement(intptr_t pos) {
2035     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2036     MlirOperation owner;
2037     if (mlirValueIsAOpResult(operand))
2038       owner = mlirOpResultGetOwner(operand);
2039     else if (mlirValueIsABlockArgument(operand))
2040       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2041     else
2042       assert(false && "Value must be an block arg or op result.");
2043     PyOperationRef pyOwner =
2044         PyOperation::forOperation(operation->getContext(), owner);
2045     return PyValue(pyOwner, operand);
2046   }
2047 
2048   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2049     return PyOpOperandList(operation, startIndex, length, step);
2050   }
2051 
2052   void dunderSetItem(intptr_t index, PyValue value) {
2053     index = wrapIndex(index);
2054     mlirOperationSetOperand(operation->get(), index, value.get());
2055   }
2056 
2057   static void bindDerived(ClassTy &c) {
2058     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2059   }
2060 
2061 private:
2062   PyOperationRef operation;
2063 };
2064 
2065 /// A list of operation results. Internally, these are stored as consecutive
2066 /// elements, random access is cheap. The result list is associated with the
2067 /// operation whose results these are, and extends the lifetime of this
2068 /// operation.
2069 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2070 public:
2071   static constexpr const char *pyClassName = "OpResultList";
2072 
2073   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2074                  intptr_t length = -1, intptr_t step = 1)
2075       : Sliceable(startIndex,
2076                   length == -1 ? mlirOperationGetNumResults(operation->get())
2077                                : length,
2078                   step),
2079         operation(operation) {}
2080 
2081   intptr_t getNumElements() {
2082     operation->checkValid();
2083     return mlirOperationGetNumResults(operation->get());
2084   }
2085 
2086   PyOpResult getElement(intptr_t index) {
2087     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2088     return PyOpResult(value);
2089   }
2090 
2091   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2092     return PyOpResultList(operation, startIndex, length, step);
2093   }
2094 
2095   static void bindDerived(ClassTy &c) {
2096     c.def_property_readonly("types", [](PyOpResultList &self) {
2097       return getValueTypes(self, self.operation->getContext());
2098     });
2099   }
2100 
2101 private:
2102   PyOperationRef operation;
2103 };
2104 
2105 /// A list of operation attributes. Can be indexed by name, producing
2106 /// attributes, or by index, producing named attributes.
2107 class PyOpAttributeMap {
2108 public:
2109   PyOpAttributeMap(PyOperationRef operation)
2110       : operation(std::move(operation)) {}
2111 
2112   PyAttribute dunderGetItemNamed(const std::string &name) {
2113     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2114                                                          toMlirStringRef(name));
2115     if (mlirAttributeIsNull(attr)) {
2116       throw SetPyError(PyExc_KeyError,
2117                        "attempt to access a non-existent attribute");
2118     }
2119     return PyAttribute(operation->getContext(), attr);
2120   }
2121 
2122   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2123     if (index < 0 || index >= dunderLen()) {
2124       throw SetPyError(PyExc_IndexError,
2125                        "attempt to access out of bounds attribute");
2126     }
2127     MlirNamedAttribute namedAttr =
2128         mlirOperationGetAttribute(operation->get(), index);
2129     return PyNamedAttribute(
2130         namedAttr.attribute,
2131         std::string(mlirIdentifierStr(namedAttr.name).data,
2132                     mlirIdentifierStr(namedAttr.name).length));
2133   }
2134 
2135   void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2136     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2137                                     attr);
2138   }
2139 
2140   void dunderDelItem(const std::string &name) {
2141     int removed = mlirOperationRemoveAttributeByName(operation->get(),
2142                                                      toMlirStringRef(name));
2143     if (!removed)
2144       throw SetPyError(PyExc_KeyError,
2145                        "attempt to delete a non-existent attribute");
2146   }
2147 
2148   intptr_t dunderLen() {
2149     return mlirOperationGetNumAttributes(operation->get());
2150   }
2151 
2152   bool dunderContains(const std::string &name) {
2153     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2154         operation->get(), toMlirStringRef(name)));
2155   }
2156 
2157   static void bind(py::module &m) {
2158     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2159         .def("__contains__", &PyOpAttributeMap::dunderContains)
2160         .def("__len__", &PyOpAttributeMap::dunderLen)
2161         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2162         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2163         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2164         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2165   }
2166 
2167 private:
2168   PyOperationRef operation;
2169 };
2170 
2171 } // namespace
2172 
2173 //------------------------------------------------------------------------------
2174 // Populates the core exports of the 'ir' submodule.
2175 //------------------------------------------------------------------------------
2176 
2177 void mlir::python::populateIRCore(py::module &m) {
2178   //----------------------------------------------------------------------------
2179   // Enums.
2180   //----------------------------------------------------------------------------
2181   py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2182       .value("ERROR", MlirDiagnosticError)
2183       .value("WARNING", MlirDiagnosticWarning)
2184       .value("NOTE", MlirDiagnosticNote)
2185       .value("REMARK", MlirDiagnosticRemark);
2186 
2187   //----------------------------------------------------------------------------
2188   // Mapping of Diagnostics.
2189   //----------------------------------------------------------------------------
2190   py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2191       .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2192       .def_property_readonly("location", &PyDiagnostic::getLocation)
2193       .def_property_readonly("message", &PyDiagnostic::getMessage)
2194       .def_property_readonly("notes", &PyDiagnostic::getNotes)
2195       .def("__str__", [](PyDiagnostic &self) -> py::str {
2196         if (!self.isValid())
2197           return "<Invalid Diagnostic>";
2198         return self.getMessage();
2199       });
2200 
2201   py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2202       .def("detach", &PyDiagnosticHandler::detach)
2203       .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2204       .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2205       .def("__enter__", &PyDiagnosticHandler::contextEnter)
2206       .def("__exit__", &PyDiagnosticHandler::contextExit);
2207 
2208   //----------------------------------------------------------------------------
2209   // Mapping of MlirContext.
2210   //----------------------------------------------------------------------------
2211   py::class_<PyMlirContext>(m, "Context", py::module_local())
2212       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2213       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2214       .def("_get_context_again",
2215            [](PyMlirContext &self) {
2216              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2217              return ref.releaseObject();
2218            })
2219       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2220       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2221       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2222       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2223                              &PyMlirContext::getCapsule)
2224       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2225       .def("__enter__", &PyMlirContext::contextEnter)
2226       .def("__exit__", &PyMlirContext::contextExit)
2227       .def_property_readonly_static(
2228           "current",
2229           [](py::object & /*class*/) {
2230             auto *context = PyThreadContextEntry::getDefaultContext();
2231             if (!context)
2232               throw SetPyError(PyExc_ValueError, "No current Context");
2233             return context;
2234           },
2235           "Gets the Context bound to the current thread or raises ValueError")
2236       .def_property_readonly(
2237           "dialects",
2238           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2239           "Gets a container for accessing dialects by name")
2240       .def_property_readonly(
2241           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2242           "Alias for 'dialect'")
2243       .def(
2244           "get_dialect_descriptor",
2245           [=](PyMlirContext &self, std::string &name) {
2246             MlirDialect dialect = mlirContextGetOrLoadDialect(
2247                 self.get(), {name.data(), name.size()});
2248             if (mlirDialectIsNull(dialect)) {
2249               throw SetPyError(PyExc_ValueError,
2250                                Twine("Dialect '") + name + "' not found");
2251             }
2252             return PyDialectDescriptor(self.getRef(), dialect);
2253           },
2254           py::arg("dialect_name"),
2255           "Gets or loads a dialect by name, returning its descriptor object")
2256       .def_property(
2257           "allow_unregistered_dialects",
2258           [](PyMlirContext &self) -> bool {
2259             return mlirContextGetAllowUnregisteredDialects(self.get());
2260           },
2261           [](PyMlirContext &self, bool value) {
2262             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2263           })
2264       .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2265            py::arg("callback"),
2266            "Attaches a diagnostic handler that will receive callbacks")
2267       .def(
2268           "enable_multithreading",
2269           [](PyMlirContext &self, bool enable) {
2270             mlirContextEnableMultithreading(self.get(), enable);
2271           },
2272           py::arg("enable"))
2273       .def(
2274           "is_registered_operation",
2275           [](PyMlirContext &self, std::string &name) {
2276             return mlirContextIsRegisteredOperation(
2277                 self.get(), MlirStringRef{name.data(), name.size()});
2278           },
2279           py::arg("operation_name"));
2280 
2281   //----------------------------------------------------------------------------
2282   // Mapping of PyDialectDescriptor
2283   //----------------------------------------------------------------------------
2284   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2285       .def_property_readonly("namespace",
2286                              [](PyDialectDescriptor &self) {
2287                                MlirStringRef ns =
2288                                    mlirDialectGetNamespace(self.get());
2289                                return py::str(ns.data, ns.length);
2290                              })
2291       .def("__repr__", [](PyDialectDescriptor &self) {
2292         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2293         std::string repr("<DialectDescriptor ");
2294         repr.append(ns.data, ns.length);
2295         repr.append(">");
2296         return repr;
2297       });
2298 
2299   //----------------------------------------------------------------------------
2300   // Mapping of PyDialects
2301   //----------------------------------------------------------------------------
2302   py::class_<PyDialects>(m, "Dialects", py::module_local())
2303       .def("__getitem__",
2304            [=](PyDialects &self, std::string keyName) {
2305              MlirDialect dialect =
2306                  self.getDialectForKey(keyName, /*attrError=*/false);
2307              py::object descriptor =
2308                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2309              return createCustomDialectWrapper(keyName, std::move(descriptor));
2310            })
2311       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2312         MlirDialect dialect =
2313             self.getDialectForKey(attrName, /*attrError=*/true);
2314         py::object descriptor =
2315             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2316         return createCustomDialectWrapper(attrName, std::move(descriptor));
2317       });
2318 
2319   //----------------------------------------------------------------------------
2320   // Mapping of PyDialect
2321   //----------------------------------------------------------------------------
2322   py::class_<PyDialect>(m, "Dialect", py::module_local())
2323       .def(py::init<py::object>(), py::arg("descriptor"))
2324       .def_property_readonly(
2325           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2326       .def("__repr__", [](py::object self) {
2327         auto clazz = self.attr("__class__");
2328         return py::str("<Dialect ") +
2329                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2330                clazz.attr("__module__") + py::str(".") +
2331                clazz.attr("__name__") + py::str(")>");
2332       });
2333 
2334   //----------------------------------------------------------------------------
2335   // Mapping of Location
2336   //----------------------------------------------------------------------------
2337   py::class_<PyLocation>(m, "Location", py::module_local())
2338       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2339       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2340       .def("__enter__", &PyLocation::contextEnter)
2341       .def("__exit__", &PyLocation::contextExit)
2342       .def("__eq__",
2343            [](PyLocation &self, PyLocation &other) -> bool {
2344              return mlirLocationEqual(self, other);
2345            })
2346       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2347       .def_property_readonly_static(
2348           "current",
2349           [](py::object & /*class*/) {
2350             auto *loc = PyThreadContextEntry::getDefaultLocation();
2351             if (!loc)
2352               throw SetPyError(PyExc_ValueError, "No current Location");
2353             return loc;
2354           },
2355           "Gets the Location bound to the current thread or raises ValueError")
2356       .def_static(
2357           "unknown",
2358           [](DefaultingPyMlirContext context) {
2359             return PyLocation(context->getRef(),
2360                               mlirLocationUnknownGet(context->get()));
2361           },
2362           py::arg("context") = py::none(),
2363           "Gets a Location representing an unknown location")
2364       .def_static(
2365           "callsite",
2366           [](PyLocation callee, const std::vector<PyLocation> &frames,
2367              DefaultingPyMlirContext context) {
2368             if (frames.empty())
2369               throw py::value_error("No caller frames provided");
2370             MlirLocation caller = frames.back().get();
2371             for (const PyLocation &frame :
2372                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2373               caller = mlirLocationCallSiteGet(frame.get(), caller);
2374             return PyLocation(context->getRef(),
2375                               mlirLocationCallSiteGet(callee.get(), caller));
2376           },
2377           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2378           kContextGetCallSiteLocationDocstring)
2379       .def_static(
2380           "file",
2381           [](std::string filename, int line, int col,
2382              DefaultingPyMlirContext context) {
2383             return PyLocation(
2384                 context->getRef(),
2385                 mlirLocationFileLineColGet(
2386                     context->get(), toMlirStringRef(filename), line, col));
2387           },
2388           py::arg("filename"), py::arg("line"), py::arg("col"),
2389           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2390       .def_static(
2391           "fused",
2392           [](const std::vector<PyLocation> &pyLocations,
2393              llvm::Optional<PyAttribute> metadata,
2394              DefaultingPyMlirContext context) {
2395             llvm::SmallVector<MlirLocation, 4> locations;
2396             locations.reserve(pyLocations.size());
2397             for (auto &pyLocation : pyLocations)
2398               locations.push_back(pyLocation.get());
2399             MlirLocation location = mlirLocationFusedGet(
2400                 context->get(), locations.size(), locations.data(),
2401                 metadata ? metadata->get() : MlirAttribute{0});
2402             return PyLocation(context->getRef(), location);
2403           },
2404           py::arg("locations"), py::arg("metadata") = py::none(),
2405           py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2406       .def_static(
2407           "name",
2408           [](std::string name, llvm::Optional<PyLocation> childLoc,
2409              DefaultingPyMlirContext context) {
2410             return PyLocation(
2411                 context->getRef(),
2412                 mlirLocationNameGet(
2413                     context->get(), toMlirStringRef(name),
2414                     childLoc ? childLoc->get()
2415                              : mlirLocationUnknownGet(context->get())));
2416           },
2417           py::arg("name"), py::arg("childLoc") = py::none(),
2418           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2419       .def_property_readonly(
2420           "context",
2421           [](PyLocation &self) { return self.getContext().getObject(); },
2422           "Context that owns the Location")
2423       .def(
2424           "emit_error",
2425           [](PyLocation &self, std::string message) {
2426             mlirEmitError(self, message.c_str());
2427           },
2428           py::arg("message"), "Emits an error at this location")
2429       .def("__repr__", [](PyLocation &self) {
2430         PyPrintAccumulator printAccum;
2431         mlirLocationPrint(self, printAccum.getCallback(),
2432                           printAccum.getUserData());
2433         return printAccum.join();
2434       });
2435 
2436   //----------------------------------------------------------------------------
2437   // Mapping of Module
2438   //----------------------------------------------------------------------------
2439   py::class_<PyModule>(m, "Module", py::module_local())
2440       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2441       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2442       .def_static(
2443           "parse",
2444           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2445             MlirModule module = mlirModuleCreateParse(
2446                 context->get(), toMlirStringRef(moduleAsm));
2447             // TODO: Rework error reporting once diagnostic engine is exposed
2448             // in C API.
2449             if (mlirModuleIsNull(module)) {
2450               throw SetPyError(
2451                   PyExc_ValueError,
2452                   "Unable to parse module assembly (see diagnostics)");
2453             }
2454             return PyModule::forModule(module).releaseObject();
2455           },
2456           py::arg("asm"), py::arg("context") = py::none(),
2457           kModuleParseDocstring)
2458       .def_static(
2459           "create",
2460           [](DefaultingPyLocation loc) {
2461             MlirModule module = mlirModuleCreateEmpty(loc);
2462             return PyModule::forModule(module).releaseObject();
2463           },
2464           py::arg("loc") = py::none(), "Creates an empty module")
2465       .def_property_readonly(
2466           "context",
2467           [](PyModule &self) { return self.getContext().getObject(); },
2468           "Context that created the Module")
2469       .def_property_readonly(
2470           "operation",
2471           [](PyModule &self) {
2472             return PyOperation::forOperation(self.getContext(),
2473                                              mlirModuleGetOperation(self.get()),
2474                                              self.getRef().releaseObject())
2475                 .releaseObject();
2476           },
2477           "Accesses the module as an operation")
2478       .def_property_readonly(
2479           "body",
2480           [](PyModule &self) {
2481             PyOperationRef moduleOp = PyOperation::forOperation(
2482                 self.getContext(), mlirModuleGetOperation(self.get()),
2483                 self.getRef().releaseObject());
2484             PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2485             return returnBlock;
2486           },
2487           "Return the block for this module")
2488       .def(
2489           "dump",
2490           [](PyModule &self) {
2491             mlirOperationDump(mlirModuleGetOperation(self.get()));
2492           },
2493           kDumpDocstring)
2494       .def(
2495           "__str__",
2496           [](py::object self) {
2497             // Defer to the operation's __str__.
2498             return self.attr("operation").attr("__str__")();
2499           },
2500           kOperationStrDunderDocstring);
2501 
2502   //----------------------------------------------------------------------------
2503   // Mapping of Operation.
2504   //----------------------------------------------------------------------------
2505   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2506       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2507                              [](PyOperationBase &self) {
2508                                return self.getOperation().getCapsule();
2509                              })
2510       .def("__eq__",
2511            [](PyOperationBase &self, PyOperationBase &other) {
2512              return &self.getOperation() == &other.getOperation();
2513            })
2514       .def("__eq__",
2515            [](PyOperationBase &self, py::object other) { return false; })
2516       .def("__hash__",
2517            [](PyOperationBase &self) {
2518              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2519            })
2520       .def_property_readonly("attributes",
2521                              [](PyOperationBase &self) {
2522                                return PyOpAttributeMap(
2523                                    self.getOperation().getRef());
2524                              })
2525       .def_property_readonly("operands",
2526                              [](PyOperationBase &self) {
2527                                return PyOpOperandList(
2528                                    self.getOperation().getRef());
2529                              })
2530       .def_property_readonly("regions",
2531                              [](PyOperationBase &self) {
2532                                return PyRegionList(
2533                                    self.getOperation().getRef());
2534                              })
2535       .def_property_readonly(
2536           "results",
2537           [](PyOperationBase &self) {
2538             return PyOpResultList(self.getOperation().getRef());
2539           },
2540           "Returns the list of Operation results.")
2541       .def_property_readonly(
2542           "result",
2543           [](PyOperationBase &self) {
2544             auto &operation = self.getOperation();
2545             auto numResults = mlirOperationGetNumResults(operation);
2546             if (numResults != 1) {
2547               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2548               throw SetPyError(
2549                   PyExc_ValueError,
2550                   Twine("Cannot call .result on operation ") +
2551                       StringRef(name.data, name.length) + " which has " +
2552                       Twine(numResults) +
2553                       " results (it is only valid for operations with a "
2554                       "single result)");
2555             }
2556             return PyOpResult(operation.getRef(),
2557                               mlirOperationGetResult(operation, 0));
2558           },
2559           "Shortcut to get an op result if it has only one (throws an error "
2560           "otherwise).")
2561       .def_property_readonly(
2562           "location",
2563           [](PyOperationBase &self) {
2564             PyOperation &operation = self.getOperation();
2565             return PyLocation(operation.getContext(),
2566                               mlirOperationGetLocation(operation.get()));
2567           },
2568           "Returns the source location the operation was defined or derived "
2569           "from.")
2570       .def(
2571           "__str__",
2572           [](PyOperationBase &self) {
2573             return self.getAsm(/*binary=*/false,
2574                                /*largeElementsLimit=*/llvm::None,
2575                                /*enableDebugInfo=*/false,
2576                                /*prettyDebugInfo=*/false,
2577                                /*printGenericOpForm=*/false,
2578                                /*useLocalScope=*/false,
2579                                /*assumeVerified=*/false);
2580           },
2581           "Returns the assembly form of the operation.")
2582       .def("print", &PyOperationBase::print,
2583            // Careful: Lots of arguments must match up with print method.
2584            py::arg("file") = py::none(), py::arg("binary") = false,
2585            py::arg("large_elements_limit") = py::none(),
2586            py::arg("enable_debug_info") = false,
2587            py::arg("pretty_debug_info") = false,
2588            py::arg("print_generic_op_form") = false,
2589            py::arg("use_local_scope") = false,
2590            py::arg("assume_verified") = false, kOperationPrintDocstring)
2591       .def("get_asm", &PyOperationBase::getAsm,
2592            // Careful: Lots of arguments must match up with get_asm method.
2593            py::arg("binary") = false,
2594            py::arg("large_elements_limit") = py::none(),
2595            py::arg("enable_debug_info") = false,
2596            py::arg("pretty_debug_info") = false,
2597            py::arg("print_generic_op_form") = false,
2598            py::arg("use_local_scope") = false,
2599            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
2600       .def(
2601           "verify",
2602           [](PyOperationBase &self) {
2603             return mlirOperationVerify(self.getOperation());
2604           },
2605           "Verify the operation and return true if it passes, false if it "
2606           "fails.")
2607       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2608            "Puts self immediately after the other operation in its parent "
2609            "block.")
2610       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2611            "Puts self immediately before the other operation in its parent "
2612            "block.")
2613       .def(
2614           "detach_from_parent",
2615           [](PyOperationBase &self) {
2616             PyOperation &operation = self.getOperation();
2617             operation.checkValid();
2618             if (!operation.isAttached())
2619               throw py::value_error("Detached operation has no parent.");
2620 
2621             operation.detachFromParent();
2622             return operation.createOpView();
2623           },
2624           "Detaches the operation from its parent block.");
2625 
2626   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2627       .def_static("create", &PyOperation::create, py::arg("name"),
2628                   py::arg("results") = py::none(),
2629                   py::arg("operands") = py::none(),
2630                   py::arg("attributes") = py::none(),
2631                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2632                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2633                   kOperationCreateDocstring)
2634       .def_property_readonly("parent",
2635                              [](PyOperation &self) -> py::object {
2636                                auto parent = self.getParentOperation();
2637                                if (parent)
2638                                  return parent->getObject();
2639                                return py::none();
2640                              })
2641       .def("erase", &PyOperation::erase)
2642       .def("clone", &PyOperation::clone, py::arg("ip") = py::none())
2643       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2644                              &PyOperation::getCapsule)
2645       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2646       .def_property_readonly("name",
2647                              [](PyOperation &self) {
2648                                self.checkValid();
2649                                MlirOperation operation = self.get();
2650                                MlirStringRef name = mlirIdentifierStr(
2651                                    mlirOperationGetName(operation));
2652                                return py::str(name.data, name.length);
2653                              })
2654       .def_property_readonly(
2655           "context",
2656           [](PyOperation &self) {
2657             self.checkValid();
2658             return self.getContext().getObject();
2659           },
2660           "Context that owns the Operation")
2661       .def_property_readonly("opview", &PyOperation::createOpView);
2662 
2663   auto opViewClass =
2664       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2665           .def(py::init<py::object>(), py::arg("operation"))
2666           .def_property_readonly("operation", &PyOpView::getOperationObject)
2667           .def_property_readonly(
2668               "context",
2669               [](PyOpView &self) {
2670                 return self.getOperation().getContext().getObject();
2671               },
2672               "Context that owns the Operation")
2673           .def("__str__", [](PyOpView &self) {
2674             return py::str(self.getOperationObject());
2675           });
2676   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2677   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2678   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2679   opViewClass.attr("build_generic") = classmethod(
2680       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2681       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2682       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2683       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2684       "Builds a specific, generated OpView based on class level attributes.");
2685 
2686   //----------------------------------------------------------------------------
2687   // Mapping of PyRegion.
2688   //----------------------------------------------------------------------------
2689   py::class_<PyRegion>(m, "Region", py::module_local())
2690       .def_property_readonly(
2691           "blocks",
2692           [](PyRegion &self) {
2693             return PyBlockList(self.getParentOperation(), self.get());
2694           },
2695           "Returns a forward-optimized sequence of blocks.")
2696       .def_property_readonly(
2697           "owner",
2698           [](PyRegion &self) {
2699             return self.getParentOperation()->createOpView();
2700           },
2701           "Returns the operation owning this region.")
2702       .def(
2703           "__iter__",
2704           [](PyRegion &self) {
2705             self.checkValid();
2706             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2707             return PyBlockIterator(self.getParentOperation(), firstBlock);
2708           },
2709           "Iterates over blocks in the region.")
2710       .def("__eq__",
2711            [](PyRegion &self, PyRegion &other) {
2712              return self.get().ptr == other.get().ptr;
2713            })
2714       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2715 
2716   //----------------------------------------------------------------------------
2717   // Mapping of PyBlock.
2718   //----------------------------------------------------------------------------
2719   py::class_<PyBlock>(m, "Block", py::module_local())
2720       .def_property_readonly(
2721           "owner",
2722           [](PyBlock &self) {
2723             return self.getParentOperation()->createOpView();
2724           },
2725           "Returns the owning operation of this block.")
2726       .def_property_readonly(
2727           "region",
2728           [](PyBlock &self) {
2729             MlirRegion region = mlirBlockGetParentRegion(self.get());
2730             return PyRegion(self.getParentOperation(), region);
2731           },
2732           "Returns the owning region of this block.")
2733       .def_property_readonly(
2734           "arguments",
2735           [](PyBlock &self) {
2736             return PyBlockArgumentList(self.getParentOperation(), self.get());
2737           },
2738           "Returns a list of block arguments.")
2739       .def_property_readonly(
2740           "operations",
2741           [](PyBlock &self) {
2742             return PyOperationList(self.getParentOperation(), self.get());
2743           },
2744           "Returns a forward-optimized sequence of operations.")
2745       .def_static(
2746           "create_at_start",
2747           [](PyRegion &parent, py::list pyArgTypes) {
2748             parent.checkValid();
2749             llvm::SmallVector<MlirType, 4> argTypes;
2750             llvm::SmallVector<MlirLocation, 4> argLocs;
2751             argTypes.reserve(pyArgTypes.size());
2752             argLocs.reserve(pyArgTypes.size());
2753             for (auto &pyArg : pyArgTypes) {
2754               argTypes.push_back(pyArg.cast<PyType &>());
2755               // TODO: Pass in a proper location here.
2756               argLocs.push_back(
2757                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2758             }
2759 
2760             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2761                                               argLocs.data());
2762             mlirRegionInsertOwnedBlock(parent, 0, block);
2763             return PyBlock(parent.getParentOperation(), block);
2764           },
2765           py::arg("parent"), py::arg("arg_types") = py::list(),
2766           "Creates and returns a new Block at the beginning of the given "
2767           "region (with given argument types).")
2768       .def(
2769           "append_to",
2770           [](PyBlock &self, PyRegion &region) {
2771             MlirBlock b = self.get();
2772             if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
2773               mlirBlockDetach(b);
2774             mlirRegionAppendOwnedBlock(region.get(), b);
2775           },
2776           "Append this block to a region, transferring ownership if necessary")
2777       .def(
2778           "create_before",
2779           [](PyBlock &self, py::args pyArgTypes) {
2780             self.checkValid();
2781             llvm::SmallVector<MlirType, 4> argTypes;
2782             llvm::SmallVector<MlirLocation, 4> argLocs;
2783             argTypes.reserve(pyArgTypes.size());
2784             argLocs.reserve(pyArgTypes.size());
2785             for (auto &pyArg : pyArgTypes) {
2786               argTypes.push_back(pyArg.cast<PyType &>());
2787               // TODO: Pass in a proper location here.
2788               argLocs.push_back(
2789                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2790             }
2791 
2792             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2793                                               argLocs.data());
2794             MlirRegion region = mlirBlockGetParentRegion(self.get());
2795             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2796             return PyBlock(self.getParentOperation(), block);
2797           },
2798           "Creates and returns a new Block before this block "
2799           "(with given argument types).")
2800       .def(
2801           "create_after",
2802           [](PyBlock &self, py::args pyArgTypes) {
2803             self.checkValid();
2804             llvm::SmallVector<MlirType, 4> argTypes;
2805             llvm::SmallVector<MlirLocation, 4> argLocs;
2806             argTypes.reserve(pyArgTypes.size());
2807             argLocs.reserve(pyArgTypes.size());
2808             for (auto &pyArg : pyArgTypes) {
2809               argTypes.push_back(pyArg.cast<PyType &>());
2810 
2811               // TODO: Pass in a proper location here.
2812               argLocs.push_back(
2813                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2814             }
2815             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2816                                               argLocs.data());
2817             MlirRegion region = mlirBlockGetParentRegion(self.get());
2818             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2819             return PyBlock(self.getParentOperation(), block);
2820           },
2821           "Creates and returns a new Block after this block "
2822           "(with given argument types).")
2823       .def(
2824           "__iter__",
2825           [](PyBlock &self) {
2826             self.checkValid();
2827             MlirOperation firstOperation =
2828                 mlirBlockGetFirstOperation(self.get());
2829             return PyOperationIterator(self.getParentOperation(),
2830                                        firstOperation);
2831           },
2832           "Iterates over operations in the block.")
2833       .def("__eq__",
2834            [](PyBlock &self, PyBlock &other) {
2835              return self.get().ptr == other.get().ptr;
2836            })
2837       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2838       .def(
2839           "__str__",
2840           [](PyBlock &self) {
2841             self.checkValid();
2842             PyPrintAccumulator printAccum;
2843             mlirBlockPrint(self.get(), printAccum.getCallback(),
2844                            printAccum.getUserData());
2845             return printAccum.join();
2846           },
2847           "Returns the assembly form of the block.")
2848       .def(
2849           "append",
2850           [](PyBlock &self, PyOperationBase &operation) {
2851             if (operation.getOperation().isAttached())
2852               operation.getOperation().detachFromParent();
2853 
2854             MlirOperation mlirOperation = operation.getOperation().get();
2855             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2856             operation.getOperation().setAttached(
2857                 self.getParentOperation().getObject());
2858           },
2859           py::arg("operation"),
2860           "Appends an operation to this block. If the operation is currently "
2861           "in another block, it will be moved.");
2862 
2863   //----------------------------------------------------------------------------
2864   // Mapping of PyInsertionPoint.
2865   //----------------------------------------------------------------------------
2866 
2867   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2868       .def(py::init<PyBlock &>(), py::arg("block"),
2869            "Inserts after the last operation but still inside the block.")
2870       .def("__enter__", &PyInsertionPoint::contextEnter)
2871       .def("__exit__", &PyInsertionPoint::contextExit)
2872       .def_property_readonly_static(
2873           "current",
2874           [](py::object & /*class*/) {
2875             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2876             if (!ip)
2877               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2878             return ip;
2879           },
2880           "Gets the InsertionPoint bound to the current thread or raises "
2881           "ValueError if none has been set")
2882       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2883            "Inserts before a referenced operation.")
2884       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2885                   py::arg("block"), "Inserts at the beginning of the block.")
2886       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2887                   py::arg("block"), "Inserts before the block terminator.")
2888       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2889            "Inserts an operation.")
2890       .def_property_readonly(
2891           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2892           "Returns the block that this InsertionPoint points to.");
2893 
2894   //----------------------------------------------------------------------------
2895   // Mapping of PyAttribute.
2896   //----------------------------------------------------------------------------
2897   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2898       // Delegate to the PyAttribute copy constructor, which will also lifetime
2899       // extend the backing context which owns the MlirAttribute.
2900       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2901            "Casts the passed attribute to the generic Attribute")
2902       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2903                              &PyAttribute::getCapsule)
2904       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2905       .def_static(
2906           "parse",
2907           [](std::string attrSpec, DefaultingPyMlirContext context) {
2908             MlirAttribute type = mlirAttributeParseGet(
2909                 context->get(), toMlirStringRef(attrSpec));
2910             // TODO: Rework error reporting once diagnostic engine is exposed
2911             // in C API.
2912             if (mlirAttributeIsNull(type)) {
2913               throw SetPyError(PyExc_ValueError,
2914                                Twine("Unable to parse attribute: '") +
2915                                    attrSpec + "'");
2916             }
2917             return PyAttribute(context->getRef(), type);
2918           },
2919           py::arg("asm"), py::arg("context") = py::none(),
2920           "Parses an attribute from an assembly form")
2921       .def_property_readonly(
2922           "context",
2923           [](PyAttribute &self) { return self.getContext().getObject(); },
2924           "Context that owns the Attribute")
2925       .def_property_readonly("type",
2926                              [](PyAttribute &self) {
2927                                return PyType(self.getContext()->getRef(),
2928                                              mlirAttributeGetType(self));
2929                              })
2930       .def(
2931           "get_named",
2932           [](PyAttribute &self, std::string name) {
2933             return PyNamedAttribute(self, std::move(name));
2934           },
2935           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2936       .def("__eq__",
2937            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2938       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2939       .def("__hash__",
2940            [](PyAttribute &self) {
2941              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2942            })
2943       .def(
2944           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2945           kDumpDocstring)
2946       .def(
2947           "__str__",
2948           [](PyAttribute &self) {
2949             PyPrintAccumulator printAccum;
2950             mlirAttributePrint(self, printAccum.getCallback(),
2951                                printAccum.getUserData());
2952             return printAccum.join();
2953           },
2954           "Returns the assembly form of the Attribute.")
2955       .def("__repr__", [](PyAttribute &self) {
2956         // Generally, assembly formats are not printed for __repr__ because
2957         // this can cause exceptionally long debug output and exceptions.
2958         // However, attribute values are generally considered useful and are
2959         // printed. This may need to be re-evaluated if debug dumps end up
2960         // being excessive.
2961         PyPrintAccumulator printAccum;
2962         printAccum.parts.append("Attribute(");
2963         mlirAttributePrint(self, printAccum.getCallback(),
2964                            printAccum.getUserData());
2965         printAccum.parts.append(")");
2966         return printAccum.join();
2967       });
2968 
2969   //----------------------------------------------------------------------------
2970   // Mapping of PyNamedAttribute
2971   //----------------------------------------------------------------------------
2972   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2973       .def("__repr__",
2974            [](PyNamedAttribute &self) {
2975              PyPrintAccumulator printAccum;
2976              printAccum.parts.append("NamedAttribute(");
2977              printAccum.parts.append(
2978                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
2979                          mlirIdentifierStr(self.namedAttr.name).length));
2980              printAccum.parts.append("=");
2981              mlirAttributePrint(self.namedAttr.attribute,
2982                                 printAccum.getCallback(),
2983                                 printAccum.getUserData());
2984              printAccum.parts.append(")");
2985              return printAccum.join();
2986            })
2987       .def_property_readonly(
2988           "name",
2989           [](PyNamedAttribute &self) {
2990             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2991                            mlirIdentifierStr(self.namedAttr.name).length);
2992           },
2993           "The name of the NamedAttribute binding")
2994       .def_property_readonly(
2995           "attr",
2996           [](PyNamedAttribute &self) {
2997             // TODO: When named attribute is removed/refactored, also remove
2998             // this constructor (it does an inefficient table lookup).
2999             auto contextRef = PyMlirContext::forContext(
3000                 mlirAttributeGetContext(self.namedAttr.attribute));
3001             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
3002           },
3003           py::keep_alive<0, 1>(),
3004           "The underlying generic attribute of the NamedAttribute binding");
3005 
3006   //----------------------------------------------------------------------------
3007   // Mapping of PyType.
3008   //----------------------------------------------------------------------------
3009   py::class_<PyType>(m, "Type", py::module_local())
3010       // Delegate to the PyType copy constructor, which will also lifetime
3011       // extend the backing context which owns the MlirType.
3012       .def(py::init<PyType &>(), py::arg("cast_from_type"),
3013            "Casts the passed type to the generic Type")
3014       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3015       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3016       .def_static(
3017           "parse",
3018           [](std::string typeSpec, DefaultingPyMlirContext context) {
3019             MlirType type =
3020                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3021             // TODO: Rework error reporting once diagnostic engine is exposed
3022             // in C API.
3023             if (mlirTypeIsNull(type)) {
3024               throw SetPyError(PyExc_ValueError,
3025                                Twine("Unable to parse type: '") + typeSpec +
3026                                    "'");
3027             }
3028             return PyType(context->getRef(), type);
3029           },
3030           py::arg("asm"), py::arg("context") = py::none(),
3031           kContextParseTypeDocstring)
3032       .def_property_readonly(
3033           "context", [](PyType &self) { return self.getContext().getObject(); },
3034           "Context that owns the Type")
3035       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3036       .def("__eq__", [](PyType &self, py::object &other) { return false; })
3037       .def("__hash__",
3038            [](PyType &self) {
3039              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3040            })
3041       .def(
3042           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3043       .def(
3044           "__str__",
3045           [](PyType &self) {
3046             PyPrintAccumulator printAccum;
3047             mlirTypePrint(self, printAccum.getCallback(),
3048                           printAccum.getUserData());
3049             return printAccum.join();
3050           },
3051           "Returns the assembly form of the type.")
3052       .def("__repr__", [](PyType &self) {
3053         // Generally, assembly formats are not printed for __repr__ because
3054         // this can cause exceptionally long debug output and exceptions.
3055         // However, types are an exception as they typically have compact
3056         // assembly forms and printing them is useful.
3057         PyPrintAccumulator printAccum;
3058         printAccum.parts.append("Type(");
3059         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
3060         printAccum.parts.append(")");
3061         return printAccum.join();
3062       });
3063 
3064   //----------------------------------------------------------------------------
3065   // Mapping of Value.
3066   //----------------------------------------------------------------------------
3067   py::class_<PyValue>(m, "Value", py::module_local())
3068       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3069       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3070       .def_property_readonly(
3071           "context",
3072           [](PyValue &self) { return self.getParentOperation()->getContext(); },
3073           "Context in which the value lives.")
3074       .def(
3075           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3076           kDumpDocstring)
3077       .def_property_readonly(
3078           "owner",
3079           [](PyValue &self) {
3080             assert(mlirOperationEqual(self.getParentOperation()->get(),
3081                                       mlirOpResultGetOwner(self.get())) &&
3082                    "expected the owner of the value in Python to match that in "
3083                    "the IR");
3084             return self.getParentOperation().getObject();
3085           })
3086       .def("__eq__",
3087            [](PyValue &self, PyValue &other) {
3088              return self.get().ptr == other.get().ptr;
3089            })
3090       .def("__eq__", [](PyValue &self, py::object other) { return false; })
3091       .def("__hash__",
3092            [](PyValue &self) {
3093              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3094            })
3095       .def(
3096           "__str__",
3097           [](PyValue &self) {
3098             PyPrintAccumulator printAccum;
3099             printAccum.parts.append("Value(");
3100             mlirValuePrint(self.get(), printAccum.getCallback(),
3101                            printAccum.getUserData());
3102             printAccum.parts.append(")");
3103             return printAccum.join();
3104           },
3105           kValueDunderStrDocstring)
3106       .def_property_readonly("type", [](PyValue &self) {
3107         return PyType(self.getParentOperation()->getContext(),
3108                       mlirValueGetType(self.get()));
3109       });
3110   PyBlockArgument::bind(m);
3111   PyOpResult::bind(m);
3112 
3113   //----------------------------------------------------------------------------
3114   // Mapping of SymbolTable.
3115   //----------------------------------------------------------------------------
3116   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3117       .def(py::init<PyOperationBase &>())
3118       .def("__getitem__", &PySymbolTable::dunderGetItem)
3119       .def("insert", &PySymbolTable::insert, py::arg("operation"))
3120       .def("erase", &PySymbolTable::erase, py::arg("operation"))
3121       .def("__delitem__", &PySymbolTable::dunderDel)
3122       .def("__contains__",
3123            [](PySymbolTable &table, const std::string &name) {
3124              return !mlirOperationIsNull(mlirSymbolTableLookup(
3125                  table, mlirStringRefCreate(name.data(), name.length())));
3126            })
3127       // Static helpers.
3128       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3129                   py::arg("symbol"), py::arg("name"))
3130       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3131                   py::arg("symbol"))
3132       .def_static("get_visibility", &PySymbolTable::getVisibility,
3133                   py::arg("symbol"))
3134       .def_static("set_visibility", &PySymbolTable::setVisibility,
3135                   py::arg("symbol"), py::arg("visibility"))
3136       .def_static("replace_all_symbol_uses",
3137                   &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3138                   py::arg("new_symbol"), py::arg("from_op"))
3139       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3140                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
3141                   py::arg("callback"));
3142 
3143   // Container bindings.
  // Each `bind(m)` registers that helper's Python class on module `m`; the
  // helpers are defined outside this chunk (presumably list/iterator views
  // over block arguments, blocks, operations, attribute maps, operands,
  // results and regions — TODO confirm against their definitions).
  // NOTE(review): keep this order; pybind11 renders signatures using the
  // classes already registered, so reordering may change generated docs.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  // Registers PyGlobalDebugFlag (defined elsewhere) on module `m`.
  PyGlobalDebugFlag::bind(m);
3157 }
3158