1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 //#include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 
23 #include <utility>
24 
25 namespace py = pybind11;
26 using namespace mlir;
27 using namespace mlir::python;
28 
29 using llvm::SmallVector;
30 using llvm::StringRef;
31 using llvm::Twine;
32 
33 //------------------------------------------------------------------------------
34 // Docstrings (trivial, non-duplicated docstrings are included inline).
35 //------------------------------------------------------------------------------
36 
// Docstring: parsing a type from its assembly form.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Docstring: building a Location from a caller and a callsite.
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

// Docstring: building a Location from a file, line and column.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Docstring: building a fused Location with optional metadata.
static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

// Docstring: building a named Location with an optional child location.
// NOTE(review): spelled "DocString" unlike its siblings; renaming it would
// require updating its users, which are outside this view.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

// Docstring: parsing a module from its textual assembly format.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring: creating a new (detached) operation; documents every argument.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
82 
// Docstring for printing an operation; documents every print keyword argument.
// Keyword names in this text must match the Python binding's keyword names
// exactly (all lower snake_case).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";
108 
// Docstring: fetching an operation's assembly form as str or bytes.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Docstring: the operation's `__str__` (assembly form with default options).
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared docstring for dump() methods that write a debug form to stderr.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring: appending a block whose argument types are the positional args.
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Docstring: a value's `__str__` (differs for block arguments vs op results).
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
147 
148 //------------------------------------------------------------------------------
149 // Utilities.
150 //------------------------------------------------------------------------------
151 
152 /// Helper for creating an @classmethod.
153 template <class Func, typename... Args>
154 py::object classmethod(Func f, Args... args) {
155   py::object cf = py::cpp_function(f, args...);
156   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
157 }
158 
159 static py::object
160 createCustomDialectWrapper(const std::string &dialectNamespace,
161                            py::object dialectDescriptor) {
162   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
163   if (!dialectClass) {
164     // Use the base class.
165     return py::cast(PyDialect(std::move(dialectDescriptor)));
166   }
167 
168   // Create the custom implementation.
169   return (*dialectClass)(std::move(dialectDescriptor));
170 }
171 
172 static MlirStringRef toMlirStringRef(const std::string &s) {
173   return mlirStringRefCreate(s.data(), s.size());
174 }
175 
176 /// Wrapper for the global LLVM debugging flag.
177 struct PyGlobalDebugFlag {
178   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
179 
180   static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
181 
182   static void bind(py::module &m) {
183     // Debug flags.
184     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
185         .def_property_static("flag", &PyGlobalDebugFlag::get,
186                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
187   }
188 };
189 
190 //------------------------------------------------------------------------------
191 // Collections.
192 //------------------------------------------------------------------------------
193 
194 namespace {
195 
196 class PyRegionIterator {
197 public:
198   PyRegionIterator(PyOperationRef operation)
199       : operation(std::move(operation)) {}
200 
201   PyRegionIterator &dunderIter() { return *this; }
202 
203   PyRegion dunderNext() {
204     operation->checkValid();
205     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
206       throw py::stop_iteration();
207     }
208     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
209     return PyRegion(operation, region);
210   }
211 
212   static void bind(py::module &m) {
213     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
214         .def("__iter__", &PyRegionIterator::dunderIter)
215         .def("__next__", &PyRegionIterator::dunderNext);
216   }
217 
218 private:
219   PyOperationRef operation;
220   int nextIndex = 0;
221 };
222 
223 /// Regions of an op are fixed length and indexed numerically so are represented
224 /// with a sequence-like container.
225 class PyRegionList {
226 public:
227   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
228 
229   intptr_t dunderLen() {
230     operation->checkValid();
231     return mlirOperationGetNumRegions(operation->get());
232   }
233 
234   PyRegion dunderGetItem(intptr_t index) {
235     // dunderLen checks validity.
236     if (index < 0 || index >= dunderLen()) {
237       throw SetPyError(PyExc_IndexError,
238                        "attempt to access out of bounds region");
239     }
240     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
241     return PyRegion(operation, region);
242   }
243 
244   static void bind(py::module &m) {
245     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
246         .def("__len__", &PyRegionList::dunderLen)
247         .def("__getitem__", &PyRegionList::dunderGetItem);
248   }
249 
250 private:
251   PyOperationRef operation;
252 };
253 
254 class PyBlockIterator {
255 public:
256   PyBlockIterator(PyOperationRef operation, MlirBlock next)
257       : operation(std::move(operation)), next(next) {}
258 
259   PyBlockIterator &dunderIter() { return *this; }
260 
261   PyBlock dunderNext() {
262     operation->checkValid();
263     if (mlirBlockIsNull(next)) {
264       throw py::stop_iteration();
265     }
266 
267     PyBlock returnBlock(operation, next);
268     next = mlirBlockGetNextInRegion(next);
269     return returnBlock;
270   }
271 
272   static void bind(py::module &m) {
273     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
274         .def("__iter__", &PyBlockIterator::dunderIter)
275         .def("__next__", &PyBlockIterator::dunderNext);
276   }
277 
278 private:
279   PyOperationRef operation;
280   MlirBlock next;
281 };
282 
283 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
284 /// we present them as a more full-featured list-like container but optimize
285 /// it for forward iteration. Blocks are always owned by a region.
286 class PyBlockList {
287 public:
288   PyBlockList(PyOperationRef operation, MlirRegion region)
289       : operation(std::move(operation)), region(region) {}
290 
291   PyBlockIterator dunderIter() {
292     operation->checkValid();
293     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
294   }
295 
296   intptr_t dunderLen() {
297     operation->checkValid();
298     intptr_t count = 0;
299     MlirBlock block = mlirRegionGetFirstBlock(region);
300     while (!mlirBlockIsNull(block)) {
301       count += 1;
302       block = mlirBlockGetNextInRegion(block);
303     }
304     return count;
305   }
306 
307   PyBlock dunderGetItem(intptr_t index) {
308     operation->checkValid();
309     if (index < 0) {
310       throw SetPyError(PyExc_IndexError,
311                        "attempt to access out of bounds block");
312     }
313     MlirBlock block = mlirRegionGetFirstBlock(region);
314     while (!mlirBlockIsNull(block)) {
315       if (index == 0) {
316         return PyBlock(operation, block);
317       }
318       block = mlirBlockGetNextInRegion(block);
319       index -= 1;
320     }
321     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
322   }
323 
324   PyBlock appendBlock(const py::args &pyArgTypes) {
325     operation->checkValid();
326     llvm::SmallVector<MlirType, 4> argTypes;
327     llvm::SmallVector<MlirLocation, 4> argLocs;
328     argTypes.reserve(pyArgTypes.size());
329     argLocs.reserve(pyArgTypes.size());
330     for (auto &pyArg : pyArgTypes) {
331       argTypes.push_back(pyArg.cast<PyType &>());
332       // TODO: Pass in a proper location here.
333       argLocs.push_back(
334           mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
335     }
336 
337     MlirBlock block =
338         mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
339     mlirRegionAppendOwnedBlock(region, block);
340     return PyBlock(operation, block);
341   }
342 
343   static void bind(py::module &m) {
344     py::class_<PyBlockList>(m, "BlockList", py::module_local())
345         .def("__getitem__", &PyBlockList::dunderGetItem)
346         .def("__iter__", &PyBlockList::dunderIter)
347         .def("__len__", &PyBlockList::dunderLen)
348         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
349   }
350 
351 private:
352   PyOperationRef operation;
353   MlirRegion region;
354 };
355 
356 class PyOperationIterator {
357 public:
358   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
359       : parentOperation(std::move(parentOperation)), next(next) {}
360 
361   PyOperationIterator &dunderIter() { return *this; }
362 
363   py::object dunderNext() {
364     parentOperation->checkValid();
365     if (mlirOperationIsNull(next)) {
366       throw py::stop_iteration();
367     }
368 
369     PyOperationRef returnOperation =
370         PyOperation::forOperation(parentOperation->getContext(), next);
371     next = mlirOperationGetNextInBlock(next);
372     return returnOperation->createOpView();
373   }
374 
375   static void bind(py::module &m) {
376     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
377         .def("__iter__", &PyOperationIterator::dunderIter)
378         .def("__next__", &PyOperationIterator::dunderNext);
379   }
380 
381 private:
382   PyOperationRef parentOperation;
383   MlirOperation next;
384 };
385 
386 /// Operations are exposed by the C-API as a forward-only linked list. In
387 /// Python, we present them as a more full-featured list-like container but
388 /// optimize it for forward iteration. Iterable operations are always owned
389 /// by a block.
390 class PyOperationList {
391 public:
392   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
393       : parentOperation(std::move(parentOperation)), block(block) {}
394 
395   PyOperationIterator dunderIter() {
396     parentOperation->checkValid();
397     return PyOperationIterator(parentOperation,
398                                mlirBlockGetFirstOperation(block));
399   }
400 
401   intptr_t dunderLen() {
402     parentOperation->checkValid();
403     intptr_t count = 0;
404     MlirOperation childOp = mlirBlockGetFirstOperation(block);
405     while (!mlirOperationIsNull(childOp)) {
406       count += 1;
407       childOp = mlirOperationGetNextInBlock(childOp);
408     }
409     return count;
410   }
411 
412   py::object dunderGetItem(intptr_t index) {
413     parentOperation->checkValid();
414     if (index < 0) {
415       throw SetPyError(PyExc_IndexError,
416                        "attempt to access out of bounds operation");
417     }
418     MlirOperation childOp = mlirBlockGetFirstOperation(block);
419     while (!mlirOperationIsNull(childOp)) {
420       if (index == 0) {
421         return PyOperation::forOperation(parentOperation->getContext(), childOp)
422             ->createOpView();
423       }
424       childOp = mlirOperationGetNextInBlock(childOp);
425       index -= 1;
426     }
427     throw SetPyError(PyExc_IndexError,
428                      "attempt to access out of bounds operation");
429   }
430 
431   static void bind(py::module &m) {
432     py::class_<PyOperationList>(m, "OperationList", py::module_local())
433         .def("__getitem__", &PyOperationList::dunderGetItem)
434         .def("__iter__", &PyOperationList::dunderIter)
435         .def("__len__", &PyOperationList::dunderLen);
436   }
437 
438 private:
439   PyOperationRef parentOperation;
440   MlirBlock block;
441 };
442 
443 } // namespace
444 
445 //------------------------------------------------------------------------------
446 // PyMlirContext
447 //------------------------------------------------------------------------------
448 
449 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
450   py::gil_scoped_acquire acquire;
451   auto &liveContexts = getLiveContexts();
452   liveContexts[context.ptr] = this;
453 }
454 
455 PyMlirContext::~PyMlirContext() {
456   // Note that the only public way to construct an instance is via the
457   // forContext method, which always puts the associated handle into
458   // liveContexts.
459   py::gil_scoped_acquire acquire;
460   getLiveContexts().erase(context.ptr);
461   mlirContextDestroy(context);
462 }
463 
464 py::object PyMlirContext::getCapsule() {
465   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
466 }
467 
468 py::object PyMlirContext::createFromCapsule(py::object capsule) {
469   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
470   if (mlirContextIsNull(rawContext))
471     throw py::error_already_set();
472   return forContext(rawContext).releaseObject();
473 }
474 
475 PyMlirContext *PyMlirContext::createNewContextForInit() {
476   MlirContext context = mlirContextCreate();
477   return new PyMlirContext(context);
478 }
479 
/// Returns a reference to the Python wrapper registered for `context` in the
/// live-context map. If none is registered, a new wrapper is created and
/// registered first.
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
  py::gil_scoped_acquire acquire;
  auto &liveContexts = getLiveContexts();
  auto it = liveContexts.find(context.ptr);
  if (it == liveContexts.end()) {
    // Create.
    // The PyMlirContext constructor itself inserts the wrapper into
    // liveContexts.
    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
    // NOTE(review): py::cast on a raw pointer uses the default
    // automatic_reference policy, which does not take ownership (Python will
    // not delete the wrapper). The "unowned" naming suggests this is intended
    // for externally owned contexts — confirm.
    py::object pyRef = py::cast(unownedContextWrapper);
    assert(pyRef && "cast to py::object failed");
    // Redundant with the registration done by the constructor; harmless.
    liveContexts[context.ptr] = unownedContextWrapper;
    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
  }
  // Use existing.
  // Presumably resolves to the Python object already associated with this
  // wrapper pointer.
  py::object pyRef = py::cast(it->second);
  return PyMlirContextRef(it->second, std::move(pyRef));
}
496 
497 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
498   static LiveContextMap liveContexts;
499   return liveContexts;
500 }
501 
502 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
503 
504 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
505 
506 size_t PyMlirContext::clearLiveOperations() {
507   for (auto &op : liveOperations)
508     op.second.second->setInvalid();
509   size_t numInvalidated = liveOperations.size();
510   liveOperations.clear();
511   return numInvalidated;
512 }
513 
514 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
515 
516 pybind11::object PyMlirContext::contextEnter() {
517   return PyThreadContextEntry::pushContext(*this);
518 }
519 
520 void PyMlirContext::contextExit(const pybind11::object &excType,
521                                 const pybind11::object &excVal,
522                                 const pybind11::object &excTb) {
523   PyThreadContextEntry::popContext(*this);
524 }
525 
526 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
527   // Note that ownership is transferred to the delete callback below by way of
528   // an explicit inc_ref (borrow).
529   PyDiagnosticHandler *pyHandler =
530       new PyDiagnosticHandler(get(), std::move(callback));
531   py::object pyHandlerObject =
532       py::cast(pyHandler, py::return_value_policy::take_ownership);
533   pyHandlerObject.inc_ref();
534 
535   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
536   // guaranteed to be known to pybind.
537   auto handlerCallback =
538       +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
539     PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
540     py::object pyDiagnosticObject =
541         py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
542 
543     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
544     bool result = false;
545     {
546       // Since this can be called from arbitrary C++ contexts, always get the
547       // gil.
548       py::gil_scoped_acquire gil;
549       try {
550         result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
551       } catch (std::exception &e) {
552         fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
553                 e.what());
554         pyHandler->hadError = true;
555       }
556     }
557 
558     pyDiagnostic->invalidate();
559     return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
560   };
561   auto deleteCallback = +[](void *userData) {
562     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
563     assert(pyHandler->registeredID && "handler is not registered");
564     pyHandler->registeredID.reset();
565 
566     // Decrement reference, balancing the inc_ref() above.
567     py::object pyHandlerObject =
568         py::cast(pyHandler, py::return_value_policy::reference);
569     pyHandlerObject.dec_ref();
570   };
571 
572   pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
573       get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
574   return pyHandlerObject;
575 }
576 
577 PyMlirContext &DefaultingPyMlirContext::resolve() {
578   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
579   if (!context) {
580     throw SetPyError(
581         PyExc_RuntimeError,
582         "An MLIR function requires a Context but none was provided in the call "
583         "or from the surrounding environment. Either pass to the function with "
584         "a 'context=' argument or establish a default using 'with Context():'");
585   }
586   return *context;
587 }
588 
589 //------------------------------------------------------------------------------
590 // PyThreadContextEntry management
591 //------------------------------------------------------------------------------
592 
593 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
594   static thread_local std::vector<PyThreadContextEntry> stack;
595   return stack;
596 }
597 
598 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
599   auto &stack = getStack();
600   if (stack.empty())
601     return nullptr;
602   return &stack.back();
603 }
604 
605 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
606                                 py::object insertionPoint,
607                                 py::object location) {
608   auto &stack = getStack();
609   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
610                      std::move(location));
611   // If the new stack has more than one entry and the context of the new top
612   // entry matches the previous, copy the insertionPoint and location from the
613   // previous entry if missing from the new top entry.
614   if (stack.size() > 1) {
615     auto &prev = *(stack.rbegin() + 1);
616     auto &current = stack.back();
617     if (current.context.is(prev.context)) {
618       // Default non-context objects from the previous entry.
619       if (!current.insertionPoint)
620         current.insertionPoint = prev.insertionPoint;
621       if (!current.location)
622         current.location = prev.location;
623     }
624   }
625 }
626 
627 PyMlirContext *PyThreadContextEntry::getContext() {
628   if (!context)
629     return nullptr;
630   return py::cast<PyMlirContext *>(context);
631 }
632 
633 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
634   if (!insertionPoint)
635     return nullptr;
636   return py::cast<PyInsertionPoint *>(insertionPoint);
637 }
638 
639 PyLocation *PyThreadContextEntry::getLocation() {
640   if (!location)
641     return nullptr;
642   return py::cast<PyLocation *>(location);
643 }
644 
645 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
646   auto *tos = getTopOfStack();
647   return tos ? tos->getContext() : nullptr;
648 }
649 
650 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
651   auto *tos = getTopOfStack();
652   return tos ? tos->getInsertionPoint() : nullptr;
653 }
654 
655 PyLocation *PyThreadContextEntry::getDefaultLocation() {
656   auto *tos = getTopOfStack();
657   return tos ? tos->getLocation() : nullptr;
658 }
659 
660 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
661   py::object contextObj = py::cast(context);
662   push(FrameKind::Context, /*context=*/contextObj,
663        /*insertionPoint=*/py::object(),
664        /*location=*/py::object());
665   return contextObj;
666 }
667 
668 void PyThreadContextEntry::popContext(PyMlirContext &context) {
669   auto &stack = getStack();
670   if (stack.empty())
671     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
672   auto &tos = stack.back();
673   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
674     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
675   stack.pop_back();
676 }
677 
678 py::object
679 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
680   py::object contextObj =
681       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
682   py::object insertionPointObj = py::cast(insertionPoint);
683   push(FrameKind::InsertionPoint,
684        /*context=*/contextObj,
685        /*insertionPoint=*/insertionPointObj,
686        /*location=*/py::object());
687   return insertionPointObj;
688 }
689 
690 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
691   auto &stack = getStack();
692   if (stack.empty())
693     throw SetPyError(PyExc_RuntimeError,
694                      "Unbalanced InsertionPoint enter/exit");
695   auto &tos = stack.back();
696   if (tos.frameKind != FrameKind::InsertionPoint &&
697       tos.getInsertionPoint() != &insertionPoint)
698     throw SetPyError(PyExc_RuntimeError,
699                      "Unbalanced InsertionPoint enter/exit");
700   stack.pop_back();
701 }
702 
703 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
704   py::object contextObj = location.getContext().getObject();
705   py::object locationObj = py::cast(location);
706   push(FrameKind::Location, /*context=*/contextObj,
707        /*insertionPoint=*/py::object(),
708        /*location=*/locationObj);
709   return locationObj;
710 }
711 
712 void PyThreadContextEntry::popLocation(PyLocation &location) {
713   auto &stack = getStack();
714   if (stack.empty())
715     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
716   auto &tos = stack.back();
717   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
718     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
719   stack.pop_back();
720 }
721 
722 //------------------------------------------------------------------------------
723 // PyDiagnostic*
724 //------------------------------------------------------------------------------
725 
726 void PyDiagnostic::invalidate() {
727   valid = false;
728   if (materializedNotes) {
729     for (auto &noteObject : *materializedNotes) {
730       PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
731       note->invalidate();
732     }
733   }
734 }
735 
736 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
737                                          py::object callback)
738     : context(context), callback(std::move(callback)) {}
739 
740 PyDiagnosticHandler::~PyDiagnosticHandler() = default;
741 
742 void PyDiagnosticHandler::detach() {
743   if (!registeredID)
744     return;
745   MlirDiagnosticHandlerID localID = *registeredID;
746   mlirContextDetachDiagnosticHandler(context, localID);
747   assert(!registeredID && "should have unregistered");
748   // Not strictly necessary but keeps stale pointers from being around to cause
749   // issues.
750   context = {nullptr};
751 }
752 
753 void PyDiagnostic::checkValid() {
754   if (!valid) {
755     throw std::invalid_argument(
756         "Diagnostic is invalid (used outside of callback)");
757   }
758 }
759 
760 MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
761   checkValid();
762   return mlirDiagnosticGetSeverity(diagnostic);
763 }
764 
765 PyLocation PyDiagnostic::getLocation() {
766   checkValid();
767   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
768   MlirContext context = mlirLocationGetContext(loc);
769   return PyLocation(PyMlirContext::forContext(context), loc);
770 }
771 
772 py::str PyDiagnostic::getMessage() {
773   checkValid();
774   py::object fileObject = py::module::import("io").attr("StringIO")();
775   PyFileAccumulator accum(fileObject, /*binary=*/false);
776   mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
777   return fileObject.attr("getvalue")();
778 }
779 
780 py::tuple PyDiagnostic::getNotes() {
781   checkValid();
782   if (materializedNotes)
783     return *materializedNotes;
784   intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
785   materializedNotes = py::tuple(numNotes);
786   for (intptr_t i = 0; i < numNotes; ++i) {
787     MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
788     py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag));
789     PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr());
790   }
791   return *materializedNotes;
792 }
793 
794 //------------------------------------------------------------------------------
795 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
796 //------------------------------------------------------------------------------
797 
798 MlirDialect PyDialects::getDialectForKey(const std::string &key,
799                                          bool attrError) {
800   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
801                                                     {key.data(), key.size()});
802   if (mlirDialectIsNull(dialect)) {
803     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
804                      Twine("Dialect '") + key + "' not found");
805   }
806   return dialect;
807 }
808 
809 py::object PyDialectRegistry::getCapsule() {
810   return py::reinterpret_steal<py::object>(
811       mlirPythonDialectRegistryToCapsule(*this));
812 }
813 
814 PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) {
815   MlirDialectRegistry rawRegistry =
816       mlirPythonCapsuleToDialectRegistry(capsule.ptr());
817   if (mlirDialectRegistryIsNull(rawRegistry))
818     throw py::error_already_set();
819   return PyDialectRegistry(rawRegistry);
820 }
821 
822 //------------------------------------------------------------------------------
823 // PyLocation
824 //------------------------------------------------------------------------------
825 
826 py::object PyLocation::getCapsule() {
827   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
828 }
829 
830 PyLocation PyLocation::createFromCapsule(py::object capsule) {
831   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
832   if (mlirLocationIsNull(rawLoc))
833     throw py::error_already_set();
834   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
835                     rawLoc);
836 }
837 
838 py::object PyLocation::contextEnter() {
839   return PyThreadContextEntry::pushLocation(*this);
840 }
841 
842 void PyLocation::contextExit(const pybind11::object &excType,
843                              const pybind11::object &excVal,
844                              const pybind11::object &excTb) {
845   PyThreadContextEntry::popLocation(*this);
846 }
847 
848 PyLocation &DefaultingPyLocation::resolve() {
849   auto *location = PyThreadContextEntry::getDefaultLocation();
850   if (!location) {
851     throw SetPyError(
852         PyExc_RuntimeError,
853         "An MLIR function requires a Location but none was provided in the "
854         "call or from the surrounding environment. Either pass to the function "
855         "with a 'loc=' argument or establish a default using 'with loc:'");
856   }
857   return *location;
858 }
859 
860 //------------------------------------------------------------------------------
861 // PyModule
862 //------------------------------------------------------------------------------
863 
864 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
865     : BaseContextObject(std::move(contextRef)), module(module) {}
866 
867 PyModule::~PyModule() {
868   py::gil_scoped_acquire acquire;
869   auto &liveModules = getContext()->liveModules;
870   assert(liveModules.count(module.ptr) == 1 &&
871          "destroying module not in live map");
872   liveModules.erase(module.ptr);
873   mlirModuleDestroy(module);
874 }
875 
876 PyModuleRef PyModule::forModule(MlirModule module) {
877   MlirContext context = mlirModuleGetContext(module);
878   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
879 
880   py::gil_scoped_acquire acquire;
881   auto &liveModules = contextRef->liveModules;
882   auto it = liveModules.find(module.ptr);
883   if (it == liveModules.end()) {
884     // Create.
885     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
886     // Note that the default return value policy on cast is automatic_reference,
887     // which does not take ownership (delete will not be called).
888     // Just be explicit.
889     py::object pyRef =
890         py::cast(unownedModule, py::return_value_policy::take_ownership);
891     unownedModule->handle = pyRef;
892     liveModules[module.ptr] =
893         std::make_pair(unownedModule->handle, unownedModule);
894     return PyModuleRef(unownedModule, std::move(pyRef));
895   }
896   // Use existing.
897   PyModule *existing = it->second.second;
898   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
899   return PyModuleRef(existing, std::move(pyRef));
900 }
901 
902 py::object PyModule::createFromCapsule(py::object capsule) {
903   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
904   if (mlirModuleIsNull(rawModule))
905     throw py::error_already_set();
906   return forModule(rawModule).releaseObject();
907 }
908 
909 py::object PyModule::getCapsule() {
910   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
911 }
912 
913 //------------------------------------------------------------------------------
914 // PyOperation
915 //------------------------------------------------------------------------------
916 
917 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
918     : BaseContextObject(std::move(contextRef)), operation(operation) {}
919 
920 PyOperation::~PyOperation() {
921   // If the operation has already been invalidated there is nothing to do.
922   if (!valid)
923     return;
924   auto &liveOperations = getContext()->liveOperations;
925   assert(liveOperations.count(operation.ptr) == 1 &&
926          "destroying operation not in live map");
927   liveOperations.erase(operation.ptr);
928   if (!isAttached()) {
929     mlirOperationDestroy(operation);
930   }
931 }
932 
933 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
934                                            MlirOperation operation,
935                                            py::object parentKeepAlive) {
936   auto &liveOperations = contextRef->liveOperations;
937   // Create.
938   PyOperation *unownedOperation =
939       new PyOperation(std::move(contextRef), operation);
940   // Note that the default return value policy on cast is automatic_reference,
941   // which does not take ownership (delete will not be called).
942   // Just be explicit.
943   py::object pyRef =
944       py::cast(unownedOperation, py::return_value_policy::take_ownership);
945   unownedOperation->handle = pyRef;
946   if (parentKeepAlive) {
947     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
948   }
949   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
950   return PyOperationRef(unownedOperation, std::move(pyRef));
951 }
952 
953 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
954                                          MlirOperation operation,
955                                          py::object parentKeepAlive) {
956   auto &liveOperations = contextRef->liveOperations;
957   auto it = liveOperations.find(operation.ptr);
958   if (it == liveOperations.end()) {
959     // Create.
960     return createInstance(std::move(contextRef), operation,
961                           std::move(parentKeepAlive));
962   }
963   // Use existing.
964   PyOperation *existing = it->second.second;
965   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
966   return PyOperationRef(existing, std::move(pyRef));
967 }
968 
969 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
970                                            MlirOperation operation,
971                                            py::object parentKeepAlive) {
972   auto &liveOperations = contextRef->liveOperations;
973   assert(liveOperations.count(operation.ptr) == 0 &&
974          "cannot create detached operation that already exists");
975   (void)liveOperations;
976 
977   PyOperationRef created = createInstance(std::move(contextRef), operation,
978                                           std::move(parentKeepAlive));
979   created->attached = false;
980   return created;
981 }
982 
983 void PyOperation::checkValid() const {
984   if (!valid) {
985     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
986   }
987 }
988 
989 void PyOperationBase::print(py::object fileObject, bool binary,
990                             llvm::Optional<int64_t> largeElementsLimit,
991                             bool enableDebugInfo, bool prettyDebugInfo,
992                             bool printGenericOpForm, bool useLocalScope,
993                             bool assumeVerified) {
994   PyOperation &operation = getOperation();
995   operation.checkValid();
996   if (fileObject.is_none())
997     fileObject = py::module::import("sys").attr("stdout");
998 
999   if (!assumeVerified && !printGenericOpForm &&
1000       !mlirOperationVerify(operation)) {
1001     std::string message("// Verification failed, printing generic form\n");
1002     if (binary) {
1003       fileObject.attr("write")(py::bytes(message));
1004     } else {
1005       fileObject.attr("write")(py::str(message));
1006     }
1007     printGenericOpForm = true;
1008   }
1009 
1010   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1011   if (largeElementsLimit)
1012     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1013   if (enableDebugInfo)
1014     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
1015   if (printGenericOpForm)
1016     mlirOpPrintingFlagsPrintGenericOpForm(flags);
1017   if (useLocalScope)
1018     mlirOpPrintingFlagsUseLocalScope(flags);
1019 
1020   PyFileAccumulator accum(fileObject, binary);
1021   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1022                               accum.getUserData());
1023   mlirOpPrintingFlagsDestroy(flags);
1024 }
1025 
1026 py::object PyOperationBase::getAsm(bool binary,
1027                                    llvm::Optional<int64_t> largeElementsLimit,
1028                                    bool enableDebugInfo, bool prettyDebugInfo,
1029                                    bool printGenericOpForm, bool useLocalScope,
1030                                    bool assumeVerified) {
1031   py::object fileObject;
1032   if (binary) {
1033     fileObject = py::module::import("io").attr("BytesIO")();
1034   } else {
1035     fileObject = py::module::import("io").attr("StringIO")();
1036   }
1037   print(fileObject, /*binary=*/binary,
1038         /*largeElementsLimit=*/largeElementsLimit,
1039         /*enableDebugInfo=*/enableDebugInfo,
1040         /*prettyDebugInfo=*/prettyDebugInfo,
1041         /*printGenericOpForm=*/printGenericOpForm,
1042         /*useLocalScope=*/useLocalScope,
1043         /*assumeVerified=*/assumeVerified);
1044 
1045   return fileObject.attr("getvalue")();
1046 }
1047 
1048 void PyOperationBase::moveAfter(PyOperationBase &other) {
1049   PyOperation &operation = getOperation();
1050   PyOperation &otherOp = other.getOperation();
1051   operation.checkValid();
1052   otherOp.checkValid();
1053   mlirOperationMoveAfter(operation, otherOp);
1054   operation.parentKeepAlive = otherOp.parentKeepAlive;
1055 }
1056 
1057 void PyOperationBase::moveBefore(PyOperationBase &other) {
1058   PyOperation &operation = getOperation();
1059   PyOperation &otherOp = other.getOperation();
1060   operation.checkValid();
1061   otherOp.checkValid();
1062   mlirOperationMoveBefore(operation, otherOp);
1063   operation.parentKeepAlive = otherOp.parentKeepAlive;
1064 }
1065 
1066 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
1067   checkValid();
1068   if (!isAttached())
1069     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
1070   MlirOperation operation = mlirOperationGetParentOperation(get());
1071   if (mlirOperationIsNull(operation))
1072     return {};
1073   return PyOperation::forOperation(getContext(), operation);
1074 }
1075 
1076 PyBlock PyOperation::getBlock() {
1077   checkValid();
1078   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
1079   MlirBlock block = mlirOperationGetBlock(get());
1080   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1081   assert(parentOperation && "Operation has no parent");
1082   return PyBlock{std::move(*parentOperation), block};
1083 }
1084 
1085 py::object PyOperation::getCapsule() {
1086   checkValid();
1087   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1088 }
1089 
1090 py::object PyOperation::createFromCapsule(py::object capsule) {
1091   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1092   if (mlirOperationIsNull(rawOperation))
1093     throw py::error_already_set();
1094   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1095   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1096       .releaseObject();
1097 }
1098 
1099 static void maybeInsertOperation(PyOperationRef &op,
1100                                  const py::object &maybeIp) {
1101   // InsertPoint active?
1102   if (!maybeIp.is(py::cast(false))) {
1103     PyInsertionPoint *ip;
1104     if (maybeIp.is_none()) {
1105       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1106     } else {
1107       ip = py::cast<PyInsertionPoint *>(maybeIp);
1108     }
1109     if (ip)
1110       ip->insert(*op.get());
1111   }
1112 }
1113 
/// Low-level operation constructor behind `Operation.create`.
///
/// Validates and converts every Python argument to C API values first, and
/// only then builds an MlirOperationState. Nothing is thrown after the state
/// exists. The operation is created detached, then inserted according to
/// `maybeIp` (see maybeInsertOperation), and returned as an OpView via
/// createOpView().
///
/// Throws ValueError for a negative `regions` count or for None
/// operands/results/successors. Throws py::cast_error for non-string attribute
/// keys or attribute values that are not Attributes.
py::object PyOperation::create(
    const std::string &name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, const py::object &maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Attribute names are kept as owned std::strings so the identifiers built
  // below can reference stable bytes.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        // It is caught before the more general cast_error handler below so
        // the message can point at the likely None value.
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Fresh, empty regions. Ownership passes to the state (and then to the
    // created operation) via mlirOperationStateAddOwnedRegions.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  // It starts out detached (owned by its Python wrapper) and may then be
  // inserted into a block according to `maybeIp`.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created->createOpView();
}
1235 
1236 py::object PyOperation::clone(const py::object &maybeIp) {
1237   MlirOperation clonedOperation = mlirOperationClone(operation);
1238   PyOperationRef cloned =
1239       PyOperation::createDetached(getContext(), clonedOperation);
1240   maybeInsertOperation(cloned, maybeIp);
1241 
1242   return cloned->createOpView();
1243 }
1244 
1245 py::object PyOperation::createOpView() {
1246   checkValid();
1247   MlirIdentifier ident = mlirOperationGetName(get());
1248   MlirStringRef identStr = mlirIdentifierStr(ident);
1249   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1250       StringRef(identStr.data, identStr.length));
1251   if (opViewClass)
1252     return (*opViewClass)(getRef().getObject());
1253   return py::cast(PyOpView(getRef().getObject()));
1254 }
1255 
1256 void PyOperation::erase() {
1257   checkValid();
1258   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1259   // Python reference to a child operation is live. All children should also
1260   // have their `valid` bit set to false.
1261   auto &liveOperations = getContext()->liveOperations;
1262   if (liveOperations.count(operation.ptr))
1263     liveOperations.erase(operation.ptr);
1264   mlirOperationDestroy(operation);
1265   valid = false;
1266 }
1267 
1268 //------------------------------------------------------------------------------
1269 // PyOpView
1270 //------------------------------------------------------------------------------
1271 
1272 py::object PyOpView::buildGeneric(
1273     const py::object &cls, py::list resultTypeList, py::list operandList,
1274     llvm::Optional<py::dict> attributes,
1275     llvm::Optional<std::vector<PyBlock *>> successors,
1276     llvm::Optional<int> regions, DefaultingPyLocation location,
1277     const py::object &maybeIp) {
1278   PyMlirContextRef context = location->getContext();
1279   // Class level operation construction metadata.
1280   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1281   // Operand and result segment specs are either none, which does no
1282   // variadic unpacking, or a list of ints with segment sizes, where each
1283   // element is either a positive number (typically 1 for a scalar) or -1 to
1284   // indicate that it is derived from the length of the same-indexed operand
1285   // or result (implying that it is a list at that position).
1286   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1287   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1288 
1289   std::vector<uint32_t> operandSegmentLengths;
1290   std::vector<uint32_t> resultSegmentLengths;
1291 
1292   // Validate/determine region count.
1293   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1294   int opMinRegionCount = std::get<0>(opRegionSpec);
1295   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1296   if (!regions) {
1297     regions = opMinRegionCount;
1298   }
1299   if (*regions < opMinRegionCount) {
1300     throw py::value_error(
1301         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1302          llvm::Twine(opMinRegionCount) +
1303          " regions but was built with regions=" + llvm::Twine(*regions))
1304             .str());
1305   }
1306   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1307     throw py::value_error(
1308         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1309          llvm::Twine(opMinRegionCount) +
1310          " regions but was built with regions=" + llvm::Twine(*regions))
1311             .str());
1312   }
1313 
1314   // Unpack results.
1315   std::vector<PyType *> resultTypes;
1316   resultTypes.reserve(resultTypeList.size());
1317   if (resultSegmentSpecObj.is_none()) {
1318     // Non-variadic result unpacking.
1319     for (const auto &it : llvm::enumerate(resultTypeList)) {
1320       try {
1321         resultTypes.push_back(py::cast<PyType *>(it.value()));
1322         if (!resultTypes.back())
1323           throw py::cast_error();
1324       } catch (py::cast_error &err) {
1325         throw py::value_error((llvm::Twine("Result ") +
1326                                llvm::Twine(it.index()) + " of operation \"" +
1327                                name + "\" must be a Type (" + err.what() + ")")
1328                                   .str());
1329       }
1330     }
1331   } else {
1332     // Sized result unpacking.
1333     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1334     if (resultSegmentSpec.size() != resultTypeList.size()) {
1335       throw py::value_error((llvm::Twine("Operation \"") + name +
1336                              "\" requires " +
1337                              llvm::Twine(resultSegmentSpec.size()) +
1338                              " result segments but was provided " +
1339                              llvm::Twine(resultTypeList.size()))
1340                                 .str());
1341     }
1342     resultSegmentLengths.reserve(resultTypeList.size());
1343     for (const auto &it :
1344          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1345       int segmentSpec = std::get<1>(it.value());
1346       if (segmentSpec == 1 || segmentSpec == 0) {
1347         // Unpack unary element.
1348         try {
1349           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1350           if (resultType) {
1351             resultTypes.push_back(resultType);
1352             resultSegmentLengths.push_back(1);
1353           } else if (segmentSpec == 0) {
1354             // Allowed to be optional.
1355             resultSegmentLengths.push_back(0);
1356           } else {
1357             throw py::cast_error("was None and result is not optional");
1358           }
1359         } catch (py::cast_error &err) {
1360           throw py::value_error((llvm::Twine("Result ") +
1361                                  llvm::Twine(it.index()) + " of operation \"" +
1362                                  name + "\" must be a Type (" + err.what() +
1363                                  ")")
1364                                     .str());
1365         }
1366       } else if (segmentSpec == -1) {
1367         // Unpack sequence by appending.
1368         try {
1369           if (std::get<0>(it.value()).is_none()) {
1370             // Treat it as an empty list.
1371             resultSegmentLengths.push_back(0);
1372           } else {
1373             // Unpack the list.
1374             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1375             for (py::object segmentItem : segment) {
1376               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1377               if (!resultTypes.back()) {
1378                 throw py::cast_error("contained a None item");
1379               }
1380             }
1381             resultSegmentLengths.push_back(segment.size());
1382           }
1383         } catch (std::exception &err) {
1384           // NOTE: Sloppy to be using a catch-all here, but there are at least
1385           // three different unrelated exceptions that can be thrown in the
1386           // above "casts". Just keep the scope above small and catch them all.
1387           throw py::value_error((llvm::Twine("Result ") +
1388                                  llvm::Twine(it.index()) + " of operation \"" +
1389                                  name + "\" must be a Sequence of Types (" +
1390                                  err.what() + ")")
1391                                     .str());
1392         }
1393       } else {
1394         throw py::value_error("Unexpected segment spec");
1395       }
1396     }
1397   }
1398 
1399   // Unpack operands.
1400   std::vector<PyValue *> operands;
1401   operands.reserve(operands.size());
1402   if (operandSegmentSpecObj.is_none()) {
1403     // Non-sized operand unpacking.
1404     for (const auto &it : llvm::enumerate(operandList)) {
1405       try {
1406         operands.push_back(py::cast<PyValue *>(it.value()));
1407         if (!operands.back())
1408           throw py::cast_error();
1409       } catch (py::cast_error &err) {
1410         throw py::value_error((llvm::Twine("Operand ") +
1411                                llvm::Twine(it.index()) + " of operation \"" +
1412                                name + "\" must be a Value (" + err.what() + ")")
1413                                   .str());
1414       }
1415     }
1416   } else {
1417     // Sized operand unpacking.
1418     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1419     if (operandSegmentSpec.size() != operandList.size()) {
1420       throw py::value_error((llvm::Twine("Operation \"") + name +
1421                              "\" requires " +
1422                              llvm::Twine(operandSegmentSpec.size()) +
1423                              "operand segments but was provided " +
1424                              llvm::Twine(operandList.size()))
1425                                 .str());
1426     }
1427     operandSegmentLengths.reserve(operandList.size());
1428     for (const auto &it :
1429          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1430       int segmentSpec = std::get<1>(it.value());
1431       if (segmentSpec == 1 || segmentSpec == 0) {
1432         // Unpack unary element.
1433         try {
1434           auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1435           if (operandValue) {
1436             operands.push_back(operandValue);
1437             operandSegmentLengths.push_back(1);
1438           } else if (segmentSpec == 0) {
1439             // Allowed to be optional.
1440             operandSegmentLengths.push_back(0);
1441           } else {
1442             throw py::cast_error("was None and operand is not optional");
1443           }
1444         } catch (py::cast_error &err) {
1445           throw py::value_error((llvm::Twine("Operand ") +
1446                                  llvm::Twine(it.index()) + " of operation \"" +
1447                                  name + "\" must be a Value (" + err.what() +
1448                                  ")")
1449                                     .str());
1450         }
1451       } else if (segmentSpec == -1) {
1452         // Unpack sequence by appending.
1453         try {
1454           if (std::get<0>(it.value()).is_none()) {
1455             // Treat it as an empty list.
1456             operandSegmentLengths.push_back(0);
1457           } else {
1458             // Unpack the list.
1459             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1460             for (py::object segmentItem : segment) {
1461               operands.push_back(py::cast<PyValue *>(segmentItem));
1462               if (!operands.back()) {
1463                 throw py::cast_error("contained a None item");
1464               }
1465             }
1466             operandSegmentLengths.push_back(segment.size());
1467           }
1468         } catch (std::exception &err) {
1469           // NOTE: Sloppy to be using a catch-all here, but there are at least
1470           // three different unrelated exceptions that can be thrown in the
1471           // above "casts". Just keep the scope above small and catch them all.
1472           throw py::value_error((llvm::Twine("Operand ") +
1473                                  llvm::Twine(it.index()) + " of operation \"" +
1474                                  name + "\" must be a Sequence of Values (" +
1475                                  err.what() + ")")
1476                                     .str());
1477         }
1478       } else {
1479         throw py::value_error("Unexpected segment spec");
1480       }
1481     }
1482   }
1483 
1484   // Merge operand/result segment lengths into attributes if needed.
1485   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1486     // Dup.
1487     if (attributes) {
1488       attributes = py::dict(*attributes);
1489     } else {
1490       attributes = py::dict();
1491     }
1492     if (attributes->contains("result_segment_sizes") ||
1493         attributes->contains("operand_segment_sizes")) {
1494       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1495                             "'operand_segment_sizes' attribute is unsupported. "
1496                             "Use Operation.create for such low-level access.");
1497     }
1498 
1499     // Add result_segment_sizes attribute.
1500     if (!resultSegmentLengths.empty()) {
1501       int64_t size = resultSegmentLengths.size();
1502       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1503           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1504           resultSegmentLengths.size(), resultSegmentLengths.data());
1505       (*attributes)["result_segment_sizes"] =
1506           PyAttribute(context, segmentLengthAttr);
1507     }
1508 
1509     // Add operand_segment_sizes attribute.
1510     if (!operandSegmentLengths.empty()) {
1511       int64_t size = operandSegmentLengths.size();
1512       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1513           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1514           operandSegmentLengths.size(), operandSegmentLengths.data());
1515       (*attributes)["operand_segment_sizes"] =
1516           PyAttribute(context, segmentLengthAttr);
1517     }
1518   }
1519 
1520   // Delegate to create.
1521   return PyOperation::create(name,
1522                              /*results=*/std::move(resultTypes),
1523                              /*operands=*/std::move(operands),
1524                              /*attributes=*/std::move(attributes),
1525                              /*successors=*/std::move(successors),
1526                              /*regions=*/*regions, location, maybeIp);
1527 }
1528 
/// Wraps an existing operation (any PyOperationBase subclass) in an OpView.
/// It keeps both a reference to the underlying PyOperation and its canonical
/// Python object.
///
/// NOTE(review): `operationObject` is initialized from the already-initialized
/// `operation` member. This relies on `operation` being declared before
/// `operationObject` in the class definition, because members are initialized
/// in declaration order. Confirm in IRModule.h.
PyOpView::PyOpView(const py::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1534 
1535 py::object PyOpView::createRawSubclass(const py::object &userClass) {
1536   // This is... a little gross. The typical pattern is to have a pure python
1537   // class that extends OpView like:
1538   //   class AddFOp(_cext.ir.OpView):
1539   //     def __init__(self, loc, lhs, rhs):
1540   //       operation = loc.context.create_operation(
1541   //           "addf", lhs, rhs, results=[lhs.type])
1542   //       super().__init__(operation)
1543   //
1544   // I.e. The goal of the user facing type is to provide a nice constructor
1545   // that has complete freedom for the op under construction. This is at odds
1546   // with our other desire to sometimes create this object by just passing an
1547   // operation (to initialize the base class). We could do *arg and **kwargs
1548   // munging to try to make it work, but instead, we synthesize a new class
1549   // on the fly which extends this user class (AddFOp in this example) and
1550   // *give it* the base class's __init__ method, thus bypassing the
1551   // intermediate subclass's __init__ method entirely. While slightly,
1552   // underhanded, this is safe/legal because the type hierarchy has not changed
1553   // (we just added a new leaf) and we aren't mucking around with __new__.
1554   // Typically, this new class will be stored on the original as "_Raw" and will
1555   // be used for casts and other things that need a variant of the class that
1556   // is initialized purely from an operation.
1557   py::object parentMetaclass =
1558       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1559   py::dict attributes;
1560   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1561   // now.
1562   //   auto opViewType = py::type::of<PyOpView>();
1563   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1564   attributes["__init__"] = opViewType.attr("__init__");
1565   py::str origName = userClass.attr("__name__");
1566   py::str newName = py::str("_") + origName;
1567   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1568 }
1569 
1570 //------------------------------------------------------------------------------
1571 // PyInsertionPoint.
1572 //------------------------------------------------------------------------------
1573 
/// Creates an insertion point in `block` with no reference operation.
/// `refOperation` is left unset, which `insert` treats as "append at the end
/// of the block".
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}

/// Creates an insertion point immediately before `beforeOperationBase`, in
/// the block that contains that operation.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      // Derived from `refOperation`, so it must be initialized after it.
      // NOTE(review): relies on member declaration order in the header.
      block((*refOperation)->getBlock()) {}
1579 
1580 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1581   PyOperation &operation = operationBase.getOperation();
1582   if (operation.isAttached())
1583     throw SetPyError(PyExc_ValueError,
1584                      "Attempt to insert operation that is already attached");
1585   block.getParentOperation()->checkValid();
1586   MlirOperation beforeOp = {nullptr};
1587   if (refOperation) {
1588     // Insert before operation.
1589     (*refOperation)->checkValid();
1590     beforeOp = (*refOperation)->get();
1591   } else {
1592     // Insert at end (before null) is only valid if the block does not
1593     // already end in a known terminator (violating this will cause assertion
1594     // failures later).
1595     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1596       throw py::index_error("Cannot insert operation at the end of a block "
1597                             "that already has a terminator. Did you mean to "
1598                             "use 'InsertionPoint.at_block_terminator(block)' "
1599                             "versus 'InsertionPoint(block)'?");
1600     }
1601   }
1602   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1603   operation.setAttached();
1604 }
1605 
1606 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1607   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1608   if (mlirOperationIsNull(firstOp)) {
1609     // Just insert at end.
1610     return PyInsertionPoint(block);
1611   }
1612 
1613   // Insert before first op.
1614   PyOperationRef firstOpRef = PyOperation::forOperation(
1615       block.getParentOperation()->getContext(), firstOp);
1616   return PyInsertionPoint{block, std::move(firstOpRef)};
1617 }
1618 
1619 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1620   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1621   if (mlirOperationIsNull(terminator))
1622     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1623   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1624       block.getParentOperation()->getContext(), terminator);
1625   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1626 }
1627 
1628 py::object PyInsertionPoint::contextEnter() {
1629   return PyThreadContextEntry::pushInsertionPoint(*this);
1630 }
1631 
1632 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1633                                    const pybind11::object &excVal,
1634                                    const pybind11::object &excTb) {
1635   PyThreadContextEntry::popInsertionPoint(*this);
1636 }
1637 
1638 //------------------------------------------------------------------------------
1639 // PyAttribute.
1640 //------------------------------------------------------------------------------
1641 
1642 bool PyAttribute::operator==(const PyAttribute &other) {
1643   return mlirAttributeEqual(attr, other.attr);
1644 }
1645 
1646 py::object PyAttribute::getCapsule() {
1647   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1648 }
1649 
1650 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1651   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1652   if (mlirAttributeIsNull(rawAttr))
1653     throw py::error_already_set();
1654   return PyAttribute(
1655       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1656 }
1657 
1658 //------------------------------------------------------------------------------
1659 // PyNamedAttribute.
1660 //------------------------------------------------------------------------------
1661 
/// Builds a named attribute from `attr` and a name owned by this object.
///
/// The name is moved into a separately heap-allocated std::string, and the
/// MlirIdentifier is created from that owned copy in the context of `attr`.
/// NOTE(review): presumably heap-allocated so the string's storage address
/// stays stable if this object is moved; confirm against the member's
/// declaration.
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
    : ownedName(new std::string(std::move(ownedName))) {
  // `this->ownedName` is the member; the same-named parameter has already
  // been moved from in the initializer above.
  namedAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(mlirAttributeGetContext(attr),
                        toMlirStringRef(*this->ownedName)),
      attr);
}
1669 
1670 //------------------------------------------------------------------------------
1671 // PyType.
1672 //------------------------------------------------------------------------------
1673 
1674 bool PyType::operator==(const PyType &other) {
1675   return mlirTypeEqual(type, other.type);
1676 }
1677 
1678 py::object PyType::getCapsule() {
1679   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1680 }
1681 
1682 PyType PyType::createFromCapsule(py::object capsule) {
1683   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1684   if (mlirTypeIsNull(rawType))
1685     throw py::error_already_set();
1686   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1687                 rawType);
1688 }
1689 
1690 //------------------------------------------------------------------------------
1691 // PyValue and subclases.
1692 //------------------------------------------------------------------------------
1693 
1694 pybind11::object PyValue::getCapsule() {
1695   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1696 }
1697 
1698 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1699   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1700   if (mlirValueIsNull(value))
1701     throw py::error_already_set();
1702   MlirOperation owner;
1703   if (mlirValueIsAOpResult(value))
1704     owner = mlirOpResultGetOwner(value);
1705   if (mlirValueIsABlockArgument(value))
1706     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1707   if (mlirOperationIsNull(owner))
1708     throw py::error_already_set();
1709   MlirContext ctx = mlirOperationGetContext(owner);
1710   PyOperationRef ownerRef =
1711       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1712   return PyValue(ownerRef, value);
1713 }
1714 
1715 //------------------------------------------------------------------------------
1716 // PySymbolTable.
1717 //------------------------------------------------------------------------------
1718 
1719 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1720     : operation(operation.getOperation().getRef()) {
1721   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1722   if (mlirSymbolTableIsNull(symbolTable)) {
1723     throw py::cast_error("Operation is not a Symbol Table.");
1724   }
1725 }
1726 
1727 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1728   operation->checkValid();
1729   MlirOperation symbol = mlirSymbolTableLookup(
1730       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1731   if (mlirOperationIsNull(symbol))
1732     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1733 
1734   return PyOperation::forOperation(operation->getContext(), symbol,
1735                                    operation.getObject())
1736       ->createOpView();
1737 }
1738 
1739 void PySymbolTable::erase(PyOperationBase &symbol) {
1740   operation->checkValid();
1741   symbol.getOperation().checkValid();
1742   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1743   // The operation is also erased, so we must invalidate it. There may be Python
1744   // references to this operation so we don't want to delete it from the list of
1745   // live operations here.
1746   symbol.getOperation().valid = false;
1747 }
1748 
1749 void PySymbolTable::dunderDel(const std::string &name) {
1750   py::object operation = dunderGetItem(name);
1751   erase(py::cast<PyOperationBase &>(operation));
1752 }
1753 
1754 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1755   operation->checkValid();
1756   symbol.getOperation().checkValid();
1757   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1758       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1759   if (mlirAttributeIsNull(symbolAttr))
1760     throw py::value_error("Expected operation to have a symbol name.");
1761   return PyAttribute(
1762       symbol.getOperation().getContext(),
1763       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1764 }
1765 
1766 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1767   // Op must already be a symbol.
1768   PyOperation &operation = symbol.getOperation();
1769   operation.checkValid();
1770   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1771   MlirAttribute existingNameAttr =
1772       mlirOperationGetAttributeByName(operation.get(), attrName);
1773   if (mlirAttributeIsNull(existingNameAttr))
1774     throw py::value_error("Expected operation to have a symbol name.");
1775   return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1776 }
1777 
1778 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1779                                   const std::string &name) {
1780   // Op must already be a symbol.
1781   PyOperation &operation = symbol.getOperation();
1782   operation.checkValid();
1783   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1784   MlirAttribute existingNameAttr =
1785       mlirOperationGetAttributeByName(operation.get(), attrName);
1786   if (mlirAttributeIsNull(existingNameAttr))
1787     throw py::value_error("Expected operation to have a symbol name.");
1788   MlirAttribute newNameAttr =
1789       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1790   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1791 }
1792 
1793 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1794   PyOperation &operation = symbol.getOperation();
1795   operation.checkValid();
1796   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1797   MlirAttribute existingVisAttr =
1798       mlirOperationGetAttributeByName(operation.get(), attrName);
1799   if (mlirAttributeIsNull(existingVisAttr))
1800     throw py::value_error("Expected operation to have a symbol visibility.");
1801   return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1802 }
1803 
1804 void PySymbolTable::setVisibility(PyOperationBase &symbol,
1805                                   const std::string &visibility) {
1806   if (visibility != "public" && visibility != "private" &&
1807       visibility != "nested")
1808     throw py::value_error(
1809         "Expected visibility to be 'public', 'private' or 'nested'");
1810   PyOperation &operation = symbol.getOperation();
1811   operation.checkValid();
1812   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1813   MlirAttribute existingVisAttr =
1814       mlirOperationGetAttributeByName(operation.get(), attrName);
1815   if (mlirAttributeIsNull(existingVisAttr))
1816     throw py::value_error("Expected operation to have a symbol visibility.");
1817   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1818                                                toMlirStringRef(visibility));
1819   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1820 }
1821 
1822 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1823                                          const std::string &newSymbol,
1824                                          PyOperationBase &from) {
1825   PyOperation &fromOperation = from.getOperation();
1826   fromOperation.checkValid();
1827   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1828           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1829           from.getOperation())))
1830 
1831     throw py::value_error("Symbol rename failed");
1832 }
1833 
1834 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1835                                      bool allSymUsesVisible,
1836                                      py::object callback) {
1837   PyOperation &fromOperation = from.getOperation();
1838   fromOperation.checkValid();
1839   struct UserData {
1840     PyMlirContextRef context;
1841     py::object callback;
1842     bool gotException;
1843     std::string exceptionWhat;
1844     py::object exceptionType;
1845   };
1846   UserData userData{
1847       fromOperation.getContext(), std::move(callback), false, {}, {}};
1848   mlirSymbolTableWalkSymbolTables(
1849       fromOperation.get(), allSymUsesVisible,
1850       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1851         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1852         auto pyFoundOp =
1853             PyOperation::forOperation(calleeUserData->context, foundOp);
1854         if (calleeUserData->gotException)
1855           return;
1856         try {
1857           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1858         } catch (py::error_already_set &e) {
1859           calleeUserData->gotException = true;
1860           calleeUserData->exceptionWhat = e.what();
1861           calleeUserData->exceptionType = e.type();
1862         }
1863       },
1864       static_cast<void *>(&userData));
1865   if (userData.gotException) {
1866     std::string message("Exception raised in callback: ");
1867     message.append(userData.exceptionWhat);
1868     throw std::runtime_error(message);
1869   }
1870 }
1871 
1872 namespace {
1873 /// CRTP base class for Python MLIR values that subclass Value and should be
1874 /// castable from it. The value hierarchy is one level deep and is not supposed
1875 /// to accommodate other levels unless core MLIR changes.
1876 template <typename DerivedTy>
1877 class PyConcreteValue : public PyValue {
1878 public:
1879   // Derived classes must define statics for:
1880   //   IsAFunctionTy isaFunction
1881   //   const char *pyClassName
1882   // and redefine bindDerived.
1883   using ClassTy = py::class_<DerivedTy, PyValue>;
1884   using IsAFunctionTy = bool (*)(MlirValue);
1885 
1886   PyConcreteValue() = default;
1887   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1888       : PyValue(operationRef, value) {}
1889   PyConcreteValue(PyValue &orig)
1890       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1891 
1892   /// Attempts to cast the original value to the derived type and throws on
1893   /// type mismatches.
1894   static MlirValue castFrom(PyValue &orig) {
1895     if (!DerivedTy::isaFunction(orig.get())) {
1896       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1897       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1898                                              DerivedTy::pyClassName +
1899                                              " (from " + origRepr + ")");
1900     }
1901     return orig.get();
1902   }
1903 
1904   /// Binds the Python module objects to functions of this class.
1905   static void bind(py::module &m) {
1906     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1907     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1908     cls.def_static(
1909         "isinstance",
1910         [](PyValue &otherValue) -> bool {
1911           return DerivedTy::isaFunction(otherValue);
1912         },
1913         py::arg("other_value"));
1914     DerivedTy::bindDerived(cls);
1915   }
1916 
1917   /// Implemented by derived classes to add methods to the Python subclass.
1918   static void bindDerived(ClassTy &m) {}
1919 };
1920 
1921 /// Python wrapper for MlirBlockArgument.
1922 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1923 public:
1924   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1925   static constexpr const char *pyClassName = "BlockArgument";
1926   using PyConcreteValue::PyConcreteValue;
1927 
1928   static void bindDerived(ClassTy &c) {
1929     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1930       return PyBlock(self.getParentOperation(),
1931                      mlirBlockArgumentGetOwner(self.get()));
1932     });
1933     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1934       return mlirBlockArgumentGetArgNumber(self.get());
1935     });
1936     c.def(
1937         "set_type",
1938         [](PyBlockArgument &self, PyType type) {
1939           return mlirBlockArgumentSetType(self.get(), type);
1940         },
1941         py::arg("type"));
1942   }
1943 };
1944 
1945 /// Python wrapper for MlirOpResult.
1946 class PyOpResult : public PyConcreteValue<PyOpResult> {
1947 public:
1948   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1949   static constexpr const char *pyClassName = "OpResult";
1950   using PyConcreteValue::PyConcreteValue;
1951 
1952   static void bindDerived(ClassTy &c) {
1953     c.def_property_readonly("owner", [](PyOpResult &self) {
1954       assert(
1955           mlirOperationEqual(self.getParentOperation()->get(),
1956                              mlirOpResultGetOwner(self.get())) &&
1957           "expected the owner of the value in Python to match that in the IR");
1958       return self.getParentOperation().getObject();
1959     });
1960     c.def_property_readonly("result_number", [](PyOpResult &self) {
1961       return mlirOpResultGetResultNumber(self.get());
1962     });
1963   }
1964 };
1965 
1966 /// Returns the list of types of the values held by container.
1967 template <typename Container>
1968 static std::vector<PyType> getValueTypes(Container &container,
1969                                          PyMlirContextRef &context) {
1970   std::vector<PyType> result;
1971   result.reserve(container.getNumElements());
1972   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1973     result.push_back(
1974         PyType(context, mlirValueGetType(container.getElement(i).get())));
1975   }
1976   return result;
1977 }
1978 
1979 /// A list of block arguments. Internally, these are stored as consecutive
1980 /// elements, random access is cheap. The argument list is associated with the
1981 /// operation that contains the block (detached blocks are not allowed in
1982 /// Python bindings) and extends its lifetime.
1983 class PyBlockArgumentList
1984     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1985 public:
1986   static constexpr const char *pyClassName = "BlockArgumentList";
1987 
1988   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1989                       intptr_t startIndex = 0, intptr_t length = -1,
1990                       intptr_t step = 1)
1991       : Sliceable(startIndex,
1992                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1993                   step),
1994         operation(std::move(operation)), block(block) {}
1995 
1996   /// Returns the number of arguments in the list.
1997   intptr_t getNumElements() {
1998     operation->checkValid();
1999     return mlirBlockGetNumArguments(block);
2000   }
2001 
2002   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
2003   PyBlockArgument getElement(intptr_t pos) {
2004     MlirValue argument = mlirBlockGetArgument(block, pos);
2005     return PyBlockArgument(operation, argument);
2006   }
2007 
2008   /// Returns a sublist of this list.
2009   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2010                             intptr_t step) {
2011     return PyBlockArgumentList(operation, block, startIndex, length, step);
2012   }
2013 
2014   static void bindDerived(ClassTy &c) {
2015     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
2016       return getValueTypes(self, self.operation->getContext());
2017     });
2018   }
2019 
2020 private:
2021   PyOperationRef operation;
2022   MlirBlock block;
2023 };
2024 
2025 /// A list of operation operands. Internally, these are stored as consecutive
2026 /// elements, random access is cheap. The result list is associated with the
2027 /// operation whose results these are, and extends the lifetime of this
2028 /// operation.
2029 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2030 public:
2031   static constexpr const char *pyClassName = "OpOperandList";
2032 
2033   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2034                   intptr_t length = -1, intptr_t step = 1)
2035       : Sliceable(startIndex,
2036                   length == -1 ? mlirOperationGetNumOperands(operation->get())
2037                                : length,
2038                   step),
2039         operation(operation) {}
2040 
2041   intptr_t getNumElements() {
2042     operation->checkValid();
2043     return mlirOperationGetNumOperands(operation->get());
2044   }
2045 
2046   PyValue getElement(intptr_t pos) {
2047     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2048     MlirOperation owner;
2049     if (mlirValueIsAOpResult(operand))
2050       owner = mlirOpResultGetOwner(operand);
2051     else if (mlirValueIsABlockArgument(operand))
2052       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2053     else
2054       assert(false && "Value must be an block arg or op result.");
2055     PyOperationRef pyOwner =
2056         PyOperation::forOperation(operation->getContext(), owner);
2057     return PyValue(pyOwner, operand);
2058   }
2059 
2060   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2061     return PyOpOperandList(operation, startIndex, length, step);
2062   }
2063 
2064   void dunderSetItem(intptr_t index, PyValue value) {
2065     index = wrapIndex(index);
2066     mlirOperationSetOperand(operation->get(), index, value.get());
2067   }
2068 
2069   static void bindDerived(ClassTy &c) {
2070     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2071   }
2072 
2073 private:
2074   PyOperationRef operation;
2075 };
2076 
2077 /// A list of operation results. Internally, these are stored as consecutive
2078 /// elements, random access is cheap. The result list is associated with the
2079 /// operation whose results these are, and extends the lifetime of this
2080 /// operation.
2081 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2082 public:
2083   static constexpr const char *pyClassName = "OpResultList";
2084 
2085   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2086                  intptr_t length = -1, intptr_t step = 1)
2087       : Sliceable(startIndex,
2088                   length == -1 ? mlirOperationGetNumResults(operation->get())
2089                                : length,
2090                   step),
2091         operation(operation) {}
2092 
2093   intptr_t getNumElements() {
2094     operation->checkValid();
2095     return mlirOperationGetNumResults(operation->get());
2096   }
2097 
2098   PyOpResult getElement(intptr_t index) {
2099     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2100     return PyOpResult(value);
2101   }
2102 
2103   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2104     return PyOpResultList(operation, startIndex, length, step);
2105   }
2106 
2107   static void bindDerived(ClassTy &c) {
2108     c.def_property_readonly("types", [](PyOpResultList &self) {
2109       return getValueTypes(self, self.operation->getContext());
2110     });
2111   }
2112 
2113 private:
2114   PyOperationRef operation;
2115 };
2116 
2117 /// A list of operation attributes. Can be indexed by name, producing
2118 /// attributes, or by index, producing named attributes.
2119 class PyOpAttributeMap {
2120 public:
2121   PyOpAttributeMap(PyOperationRef operation)
2122       : operation(std::move(operation)) {}
2123 
2124   PyAttribute dunderGetItemNamed(const std::string &name) {
2125     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2126                                                          toMlirStringRef(name));
2127     if (mlirAttributeIsNull(attr)) {
2128       throw SetPyError(PyExc_KeyError,
2129                        "attempt to access a non-existent attribute");
2130     }
2131     return PyAttribute(operation->getContext(), attr);
2132   }
2133 
2134   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2135     if (index < 0 || index >= dunderLen()) {
2136       throw SetPyError(PyExc_IndexError,
2137                        "attempt to access out of bounds attribute");
2138     }
2139     MlirNamedAttribute namedAttr =
2140         mlirOperationGetAttribute(operation->get(), index);
2141     return PyNamedAttribute(
2142         namedAttr.attribute,
2143         std::string(mlirIdentifierStr(namedAttr.name).data,
2144                     mlirIdentifierStr(namedAttr.name).length));
2145   }
2146 
2147   void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2148     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2149                                     attr);
2150   }
2151 
2152   void dunderDelItem(const std::string &name) {
2153     int removed = mlirOperationRemoveAttributeByName(operation->get(),
2154                                                      toMlirStringRef(name));
2155     if (!removed)
2156       throw SetPyError(PyExc_KeyError,
2157                        "attempt to delete a non-existent attribute");
2158   }
2159 
2160   intptr_t dunderLen() {
2161     return mlirOperationGetNumAttributes(operation->get());
2162   }
2163 
2164   bool dunderContains(const std::string &name) {
2165     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2166         operation->get(), toMlirStringRef(name)));
2167   }
2168 
2169   static void bind(py::module &m) {
2170     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2171         .def("__contains__", &PyOpAttributeMap::dunderContains)
2172         .def("__len__", &PyOpAttributeMap::dunderLen)
2173         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2174         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2175         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2176         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2177   }
2178 
2179 private:
2180   PyOperationRef operation;
2181 };
2182 
2183 } // namespace
2184 
2185 //------------------------------------------------------------------------------
2186 // Populates the core exports of the 'ir' submodule.
2187 //------------------------------------------------------------------------------
2188 
2189 void mlir::python::populateIRCore(py::module &m) {
2190   //----------------------------------------------------------------------------
2191   // Enums.
2192   //----------------------------------------------------------------------------
2193   py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2194       .value("ERROR", MlirDiagnosticError)
2195       .value("WARNING", MlirDiagnosticWarning)
2196       .value("NOTE", MlirDiagnosticNote)
2197       .value("REMARK", MlirDiagnosticRemark);
2198 
2199   //----------------------------------------------------------------------------
2200   // Mapping of Diagnostics.
2201   //----------------------------------------------------------------------------
2202   py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2203       .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2204       .def_property_readonly("location", &PyDiagnostic::getLocation)
2205       .def_property_readonly("message", &PyDiagnostic::getMessage)
2206       .def_property_readonly("notes", &PyDiagnostic::getNotes)
2207       .def("__str__", [](PyDiagnostic &self) -> py::str {
2208         if (!self.isValid())
2209           return "<Invalid Diagnostic>";
2210         return self.getMessage();
2211       });
2212 
2213   py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2214       .def("detach", &PyDiagnosticHandler::detach)
2215       .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2216       .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2217       .def("__enter__", &PyDiagnosticHandler::contextEnter)
2218       .def("__exit__", &PyDiagnosticHandler::contextExit);
2219 
2220   //----------------------------------------------------------------------------
2221   // Mapping of MlirContext.
2222   // Note that this is exported as _BaseContext. The containing, Python level
2223   // __init__.py will subclass it with site-specific functionality and set a
2224   // "Context" attribute on this module.
2225   //----------------------------------------------------------------------------
2226   py::class_<PyMlirContext>(m, "_BaseContext", py::module_local())
2227       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2228       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2229       .def("_get_context_again",
2230            [](PyMlirContext &self) {
2231              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2232              return ref.releaseObject();
2233            })
2234       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2235       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2236       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2237       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2238                              &PyMlirContext::getCapsule)
2239       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2240       .def("__enter__", &PyMlirContext::contextEnter)
2241       .def("__exit__", &PyMlirContext::contextExit)
2242       .def_property_readonly_static(
2243           "current",
2244           [](py::object & /*class*/) {
2245             auto *context = PyThreadContextEntry::getDefaultContext();
2246             if (!context)
2247               throw SetPyError(PyExc_ValueError, "No current Context");
2248             return context;
2249           },
2250           "Gets the Context bound to the current thread or raises ValueError")
2251       .def_property_readonly(
2252           "dialects",
2253           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2254           "Gets a container for accessing dialects by name")
2255       .def_property_readonly(
2256           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2257           "Alias for 'dialect'")
2258       .def(
2259           "get_dialect_descriptor",
2260           [=](PyMlirContext &self, std::string &name) {
2261             MlirDialect dialect = mlirContextGetOrLoadDialect(
2262                 self.get(), {name.data(), name.size()});
2263             if (mlirDialectIsNull(dialect)) {
2264               throw SetPyError(PyExc_ValueError,
2265                                Twine("Dialect '") + name + "' not found");
2266             }
2267             return PyDialectDescriptor(self.getRef(), dialect);
2268           },
2269           py::arg("dialect_name"),
2270           "Gets or loads a dialect by name, returning its descriptor object")
2271       .def_property(
2272           "allow_unregistered_dialects",
2273           [](PyMlirContext &self) -> bool {
2274             return mlirContextGetAllowUnregisteredDialects(self.get());
2275           },
2276           [](PyMlirContext &self, bool value) {
2277             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2278           })
2279       .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2280            py::arg("callback"),
2281            "Attaches a diagnostic handler that will receive callbacks")
2282       .def(
2283           "enable_multithreading",
2284           [](PyMlirContext &self, bool enable) {
2285             mlirContextEnableMultithreading(self.get(), enable);
2286           },
2287           py::arg("enable"))
2288       .def(
2289           "is_registered_operation",
2290           [](PyMlirContext &self, std::string &name) {
2291             return mlirContextIsRegisteredOperation(
2292                 self.get(), MlirStringRef{name.data(), name.size()});
2293           },
2294           py::arg("operation_name"))
2295       .def(
2296           "append_dialect_registry",
2297           [](PyMlirContext &self, PyDialectRegistry &registry) {
2298             mlirContextAppendDialectRegistry(self.get(), registry);
2299           },
2300           py::arg("registry"))
2301       .def("load_all_available_dialects", [](PyMlirContext &self) {
2302         mlirContextLoadAllAvailableDialects(self.get());
2303       });
2304 
2305   //----------------------------------------------------------------------------
2306   // Mapping of PyDialectDescriptor
2307   //----------------------------------------------------------------------------
2308   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2309       .def_property_readonly("namespace",
2310                              [](PyDialectDescriptor &self) {
2311                                MlirStringRef ns =
2312                                    mlirDialectGetNamespace(self.get());
2313                                return py::str(ns.data, ns.length);
2314                              })
2315       .def("__repr__", [](PyDialectDescriptor &self) {
2316         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2317         std::string repr("<DialectDescriptor ");
2318         repr.append(ns.data, ns.length);
2319         repr.append(">");
2320         return repr;
2321       });
2322 
2323   //----------------------------------------------------------------------------
2324   // Mapping of PyDialects
2325   //----------------------------------------------------------------------------
2326   py::class_<PyDialects>(m, "Dialects", py::module_local())
2327       .def("__getitem__",
2328            [=](PyDialects &self, std::string keyName) {
2329              MlirDialect dialect =
2330                  self.getDialectForKey(keyName, /*attrError=*/false);
2331              py::object descriptor =
2332                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2333              return createCustomDialectWrapper(keyName, std::move(descriptor));
2334            })
2335       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2336         MlirDialect dialect =
2337             self.getDialectForKey(attrName, /*attrError=*/true);
2338         py::object descriptor =
2339             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2340         return createCustomDialectWrapper(attrName, std::move(descriptor));
2341       });
2342 
2343   //----------------------------------------------------------------------------
2344   // Mapping of PyDialect
2345   //----------------------------------------------------------------------------
2346   py::class_<PyDialect>(m, "Dialect", py::module_local())
2347       .def(py::init<py::object>(), py::arg("descriptor"))
2348       .def_property_readonly(
2349           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2350       .def("__repr__", [](py::object self) {
2351         auto clazz = self.attr("__class__");
2352         return py::str("<Dialect ") +
2353                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2354                clazz.attr("__module__") + py::str(".") +
2355                clazz.attr("__name__") + py::str(")>");
2356       });
2357 
2358   //----------------------------------------------------------------------------
2359   // Mapping of PyDialectRegistry
2360   //----------------------------------------------------------------------------
2361   py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local())
2362       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2363                              &PyDialectRegistry::getCapsule)
2364       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2365       .def(py::init<>());
2366 
2367   //----------------------------------------------------------------------------
2368   // Mapping of Location
2369   //----------------------------------------------------------------------------
  // Binds PyLocation as `Location`: capsule interop, context-manager support,
  // equality, the thread-bound "current" location, and static factories for
  // each location kind.
  py::class_<PyLocation>(m, "Location", py::module_local())
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
      .def("__enter__", &PyLocation::contextEnter)
      .def("__exit__", &PyLocation::contextExit)
      // Location == Location compares via mlirLocationEqual.
      .def("__eq__",
           [](PyLocation &self, PyLocation &other) -> bool {
             return mlirLocationEqual(self, other);
           })
      // Fallback overload for non-Location operands: always unequal. pybind11
      // tries overloads in registration order, so this must stay after the
      // typed overload above.
      .def("__eq__", [](PyLocation &self, py::object other) { return false; })
      .def_property_readonly_static(
          "current",
          [](py::object & /*class*/) {
            auto *loc = PyThreadContextEntry::getDefaultLocation();
            if (!loc)
              throw SetPyError(PyExc_ValueError, "No current Location");
            return loc;
          },
          "Gets the Location bound to the current thread or raises ValueError")
      .def_static(
          "unknown",
          [](DefaultingPyMlirContext context) {
            return PyLocation(context->getRef(),
                              mlirLocationUnknownGet(context->get()));
          },
          py::arg("context") = py::none(),
          "Gets a Location representing an unknown location")
      .def_static(
          "callsite",
          [](PyLocation callee, const std::vector<PyLocation> &frames,
             DefaultingPyMlirContext context) {
            // Raises ValueError when no frames are given.
            if (frames.empty())
              throw py::value_error("No caller frames provided");
            // Fold the frames from the back: the last frame becomes the
            // innermost caller and each earlier frame wraps it. The result is
            // callsite(callee, callsite(frames[0], callsite(frames[1], ...))),
            // i.e. frames[0] is the direct caller of `callee`.
            MlirLocation caller = frames.back().get();
            for (const PyLocation &frame :
                 llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
              caller = mlirLocationCallSiteGet(frame.get(), caller);
            return PyLocation(context->getRef(),
                              mlirLocationCallSiteGet(callee.get(), caller));
          },
          py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
          kContextGetCallSiteLocationDocstring)
      .def_static(
          "file",
          [](std::string filename, int line, int col,
             DefaultingPyMlirContext context) {
            return PyLocation(
                context->getRef(),
                mlirLocationFileLineColGet(
                    context->get(), toMlirStringRef(filename), line, col));
          },
          py::arg("filename"), py::arg("line"), py::arg("col"),
          py::arg("context") = py::none(), kContextGetFileLocationDocstring)
      .def_static(
          "fused",
          [](const std::vector<PyLocation> &pyLocations,
             llvm::Optional<PyAttribute> metadata,
             DefaultingPyMlirContext context) {
            // Unwrap the Python locations into a contiguous C array for the
            // C API.
            llvm::SmallVector<MlirLocation, 4> locations;
            locations.reserve(pyLocations.size());
            for (auto &pyLocation : pyLocations)
              locations.push_back(pyLocation.get());
            // Absent metadata is passed as a null MlirAttribute.
            MlirLocation location = mlirLocationFusedGet(
                context->get(), locations.size(), locations.data(),
                metadata ? metadata->get() : MlirAttribute{0});
            return PyLocation(context->getRef(), location);
          },
          py::arg("locations"), py::arg("metadata") = py::none(),
          py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
      .def_static(
          "name",
          [](std::string name, llvm::Optional<PyLocation> childLoc,
             DefaultingPyMlirContext context) {
            // Without an explicit child, the named location wraps an unknown
            // location from the same context.
            return PyLocation(
                context->getRef(),
                mlirLocationNameGet(
                    context->get(), toMlirStringRef(name),
                    childLoc ? childLoc->get()
                             : mlirLocationUnknownGet(context->get())));
          },
          // NOTE(review): "childLoc" is camelCase unlike the other keyword
          // names in this file; renaming it would break keyword callers.
          py::arg("name"), py::arg("childLoc") = py::none(),
          py::arg("context") = py::none(), kContextGetNameLocationDocString)
      .def_property_readonly(
          "context",
          [](PyLocation &self) { return self.getContext().getObject(); },
          "Context that owns the Location")
      .def(
          "emit_error",
          [](PyLocation &self, std::string message) {
            mlirEmitError(self, message.c_str());
          },
          py::arg("message"), "Emits an error at this location")
      // repr is the printed form of the location, accumulated through the C
      // API print callback.
      .def("__repr__", [](PyLocation &self) {
        PyPrintAccumulator printAccum;
        mlirLocationPrint(self, printAccum.getCallback(),
                          printAccum.getUserData());
        return printAccum.join();
      });
2468 
2469   //----------------------------------------------------------------------------
2470   // Mapping of Module
2471   //----------------------------------------------------------------------------
2472   py::class_<PyModule>(m, "Module", py::module_local())
2473       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2474       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2475       .def_static(
2476           "parse",
2477           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2478             MlirModule module = mlirModuleCreateParse(
2479                 context->get(), toMlirStringRef(moduleAsm));
2480             // TODO: Rework error reporting once diagnostic engine is exposed
2481             // in C API.
2482             if (mlirModuleIsNull(module)) {
2483               throw SetPyError(
2484                   PyExc_ValueError,
2485                   "Unable to parse module assembly (see diagnostics)");
2486             }
2487             return PyModule::forModule(module).releaseObject();
2488           },
2489           py::arg("asm"), py::arg("context") = py::none(),
2490           kModuleParseDocstring)
2491       .def_static(
2492           "create",
2493           [](DefaultingPyLocation loc) {
2494             MlirModule module = mlirModuleCreateEmpty(loc);
2495             return PyModule::forModule(module).releaseObject();
2496           },
2497           py::arg("loc") = py::none(), "Creates an empty module")
2498       .def_property_readonly(
2499           "context",
2500           [](PyModule &self) { return self.getContext().getObject(); },
2501           "Context that created the Module")
2502       .def_property_readonly(
2503           "operation",
2504           [](PyModule &self) {
2505             return PyOperation::forOperation(self.getContext(),
2506                                              mlirModuleGetOperation(self.get()),
2507                                              self.getRef().releaseObject())
2508                 .releaseObject();
2509           },
2510           "Accesses the module as an operation")
2511       .def_property_readonly(
2512           "body",
2513           [](PyModule &self) {
2514             PyOperationRef moduleOp = PyOperation::forOperation(
2515                 self.getContext(), mlirModuleGetOperation(self.get()),
2516                 self.getRef().releaseObject());
2517             PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2518             return returnBlock;
2519           },
2520           "Return the block for this module")
2521       .def(
2522           "dump",
2523           [](PyModule &self) {
2524             mlirOperationDump(mlirModuleGetOperation(self.get()));
2525           },
2526           kDumpDocstring)
2527       .def(
2528           "__str__",
2529           [](py::object self) {
2530             // Defer to the operation's __str__.
2531             return self.attr("operation").attr("__str__")();
2532           },
2533           kOperationStrDunderDocstring);
2534 
2535   //----------------------------------------------------------------------------
2536   // Mapping of Operation.
2537   //----------------------------------------------------------------------------
  // Binds PyOperationBase as `_OperationBase`, the shared base of Operation
  // and OpView. Every accessor goes through getOperation(), so the same API is
  // available on both the raw operation and its op views.
  py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                             [](PyOperationBase &self) {
                               return self.getOperation().getCapsule();
                             })
      // Two wrappers are equal when they refer to the same PyOperation object
      // (address comparison).
      .def("__eq__",
           [](PyOperationBase &self, PyOperationBase &other) {
             return &self.getOperation() == &other.getOperation();
           })
      // Fallback for non-operation operands: always unequal. Must remain after
      // the typed overload, since pybind11 tries overloads in order.
      .def("__eq__",
           [](PyOperationBase &self, py::object other) { return false; })
      // Hashes the same PyOperation address that __eq__ compares, so equal
      // objects hash equally.
      .def("__hash__",
           [](PyOperationBase &self) {
             return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
           })
      .def_property_readonly("attributes",
                             [](PyOperationBase &self) {
                               return PyOpAttributeMap(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("operands",
                             [](PyOperationBase &self) {
                               return PyOpOperandList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("regions",
                             [](PyOperationBase &self) {
                               return PyRegionList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly(
          "results",
          [](PyOperationBase &self) {
            return PyOpResultList(self.getOperation().getRef());
          },
          "Returns the list of Operation results.")
      .def_property_readonly(
          "result",
          [](PyOperationBase &self) {
            // Raises ValueError unless the operation has exactly one result.
            auto &operation = self.getOperation();
            auto numResults = mlirOperationGetNumResults(operation);
            if (numResults != 1) {
              auto name = mlirIdentifierStr(mlirOperationGetName(operation));
              throw SetPyError(
                  PyExc_ValueError,
                  Twine("Cannot call .result on operation ") +
                      StringRef(name.data, name.length) + " which has " +
                      Twine(numResults) +
                      " results (it is only valid for operations with a "
                      "single result)");
            }
            return PyOpResult(operation.getRef(),
                              mlirOperationGetResult(operation, 0));
          },
          "Shortcut to get an op result if it has only one (throws an error "
          "otherwise).")
      .def_property_readonly(
          "location",
          [](PyOperationBase &self) {
            PyOperation &operation = self.getOperation();
            return PyLocation(operation.getContext(),
                              mlirOperationGetLocation(operation.get()));
          },
          "Returns the source location the operation was defined or derived "
          "from.")
      // str() is getAsm with every printing option at its default (textual,
      // no elements limit, no debug info, custom op form, non-local scope,
      // verification not assumed).
      .def(
          "__str__",
          [](PyOperationBase &self) {
            return self.getAsm(/*binary=*/false,
                               /*largeElementsLimit=*/llvm::None,
                               /*enableDebugInfo=*/false,
                               /*prettyDebugInfo=*/false,
                               /*printGenericOpForm=*/false,
                               /*useLocalScope=*/false,
                               /*assumeVerified=*/false);
          },
          "Returns the assembly form of the operation.")
      .def("print", &PyOperationBase::print,
           // Careful: Lots of arguments must match up with print method.
           py::arg("file") = py::none(), py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false,
           py::arg("assume_verified") = false, kOperationPrintDocstring)
      .def("get_asm", &PyOperationBase::getAsm,
           // Careful: Lots of arguments must match up with get_asm method.
           py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false,
           py::arg("assume_verified") = false, kOperationGetAsmDocstring)
      .def(
          "verify",
          [](PyOperationBase &self) {
            return mlirOperationVerify(self.getOperation());
          },
          "Verify the operation and return true if it passes, false if it "
          "fails.")
      .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
           "Puts self immediately after the other operation in its parent "
           "block.")
      .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
           "Puts self immediately before the other operation in its parent "
           "block.")
      .def(
          "detach_from_parent",
          [](PyOperationBase &self) {
            // Rejects invalidated operations (checkValid) and ones that are
            // not attached (ValueError), then returns the detached op as an
            // OpView.
            PyOperation &operation = self.getOperation();
            operation.checkValid();
            if (!operation.isAttached())
              throw py::value_error("Detached operation has no parent.");

            operation.detachFromParent();
            return operation.createOpView();
          },
          "Detaches the operation from its parent block.");
2658 
2659   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2660       .def_static("create", &PyOperation::create, py::arg("name"),
2661                   py::arg("results") = py::none(),
2662                   py::arg("operands") = py::none(),
2663                   py::arg("attributes") = py::none(),
2664                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2665                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2666                   kOperationCreateDocstring)
2667       .def_property_readonly("parent",
2668                              [](PyOperation &self) -> py::object {
2669                                auto parent = self.getParentOperation();
2670                                if (parent)
2671                                  return parent->getObject();
2672                                return py::none();
2673                              })
2674       .def("erase", &PyOperation::erase)
2675       .def("clone", &PyOperation::clone, py::arg("ip") = py::none())
2676       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2677                              &PyOperation::getCapsule)
2678       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2679       .def_property_readonly("name",
2680                              [](PyOperation &self) {
2681                                self.checkValid();
2682                                MlirOperation operation = self.get();
2683                                MlirStringRef name = mlirIdentifierStr(
2684                                    mlirOperationGetName(operation));
2685                                return py::str(name.data, name.length);
2686                              })
2687       .def_property_readonly(
2688           "context",
2689           [](PyOperation &self) {
2690             self.checkValid();
2691             return self.getContext().getObject();
2692           },
2693           "Context that owns the Operation")
2694       .def_property_readonly("opview", &PyOperation::createOpView);
2695 
2696   auto opViewClass =
2697       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2698           .def(py::init<py::object>(), py::arg("operation"))
2699           .def_property_readonly("operation", &PyOpView::getOperationObject)
2700           .def_property_readonly(
2701               "context",
2702               [](PyOpView &self) {
2703                 return self.getOperation().getContext().getObject();
2704               },
2705               "Context that owns the Operation")
2706           .def("__str__", [](PyOpView &self) {
2707             return py::str(self.getOperationObject());
2708           });
2709   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2710   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2711   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2712   opViewClass.attr("build_generic") = classmethod(
2713       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2714       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2715       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2716       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2717       "Builds a specific, generated OpView based on class level attributes.");
2718 
2719   //----------------------------------------------------------------------------
2720   // Mapping of PyRegion.
2721   //----------------------------------------------------------------------------
2722   py::class_<PyRegion>(m, "Region", py::module_local())
2723       .def_property_readonly(
2724           "blocks",
2725           [](PyRegion &self) {
2726             return PyBlockList(self.getParentOperation(), self.get());
2727           },
2728           "Returns a forward-optimized sequence of blocks.")
2729       .def_property_readonly(
2730           "owner",
2731           [](PyRegion &self) {
2732             return self.getParentOperation()->createOpView();
2733           },
2734           "Returns the operation owning this region.")
2735       .def(
2736           "__iter__",
2737           [](PyRegion &self) {
2738             self.checkValid();
2739             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2740             return PyBlockIterator(self.getParentOperation(), firstBlock);
2741           },
2742           "Iterates over blocks in the region.")
2743       .def("__eq__",
2744            [](PyRegion &self, PyRegion &other) {
2745              return self.get().ptr == other.get().ptr;
2746            })
2747       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2748 
2749   //----------------------------------------------------------------------------
2750   // Mapping of PyBlock.
2751   //----------------------------------------------------------------------------
2752   py::class_<PyBlock>(m, "Block", py::module_local())
2753       .def_property_readonly(
2754           "owner",
2755           [](PyBlock &self) {
2756             return self.getParentOperation()->createOpView();
2757           },
2758           "Returns the owning operation of this block.")
2759       .def_property_readonly(
2760           "region",
2761           [](PyBlock &self) {
2762             MlirRegion region = mlirBlockGetParentRegion(self.get());
2763             return PyRegion(self.getParentOperation(), region);
2764           },
2765           "Returns the owning region of this block.")
2766       .def_property_readonly(
2767           "arguments",
2768           [](PyBlock &self) {
2769             return PyBlockArgumentList(self.getParentOperation(), self.get());
2770           },
2771           "Returns a list of block arguments.")
2772       .def_property_readonly(
2773           "operations",
2774           [](PyBlock &self) {
2775             return PyOperationList(self.getParentOperation(), self.get());
2776           },
2777           "Returns a forward-optimized sequence of operations.")
2778       .def_static(
2779           "create_at_start",
2780           [](PyRegion &parent, py::list pyArgTypes) {
2781             parent.checkValid();
2782             llvm::SmallVector<MlirType, 4> argTypes;
2783             llvm::SmallVector<MlirLocation, 4> argLocs;
2784             argTypes.reserve(pyArgTypes.size());
2785             argLocs.reserve(pyArgTypes.size());
2786             for (auto &pyArg : pyArgTypes) {
2787               argTypes.push_back(pyArg.cast<PyType &>());
2788               // TODO: Pass in a proper location here.
2789               argLocs.push_back(
2790                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2791             }
2792 
2793             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2794                                               argLocs.data());
2795             mlirRegionInsertOwnedBlock(parent, 0, block);
2796             return PyBlock(parent.getParentOperation(), block);
2797           },
2798           py::arg("parent"), py::arg("arg_types") = py::list(),
2799           "Creates and returns a new Block at the beginning of the given "
2800           "region (with given argument types).")
2801       .def(
2802           "append_to",
2803           [](PyBlock &self, PyRegion &region) {
2804             MlirBlock b = self.get();
2805             if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
2806               mlirBlockDetach(b);
2807             mlirRegionAppendOwnedBlock(region.get(), b);
2808           },
2809           "Append this block to a region, transferring ownership if necessary")
2810       .def(
2811           "create_before",
2812           [](PyBlock &self, py::args pyArgTypes) {
2813             self.checkValid();
2814             llvm::SmallVector<MlirType, 4> argTypes;
2815             llvm::SmallVector<MlirLocation, 4> argLocs;
2816             argTypes.reserve(pyArgTypes.size());
2817             argLocs.reserve(pyArgTypes.size());
2818             for (auto &pyArg : pyArgTypes) {
2819               argTypes.push_back(pyArg.cast<PyType &>());
2820               // TODO: Pass in a proper location here.
2821               argLocs.push_back(
2822                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2823             }
2824 
2825             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2826                                               argLocs.data());
2827             MlirRegion region = mlirBlockGetParentRegion(self.get());
2828             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2829             return PyBlock(self.getParentOperation(), block);
2830           },
2831           "Creates and returns a new Block before this block "
2832           "(with given argument types).")
2833       .def(
2834           "create_after",
2835           [](PyBlock &self, py::args pyArgTypes) {
2836             self.checkValid();
2837             llvm::SmallVector<MlirType, 4> argTypes;
2838             llvm::SmallVector<MlirLocation, 4> argLocs;
2839             argTypes.reserve(pyArgTypes.size());
2840             argLocs.reserve(pyArgTypes.size());
2841             for (auto &pyArg : pyArgTypes) {
2842               argTypes.push_back(pyArg.cast<PyType &>());
2843 
2844               // TODO: Pass in a proper location here.
2845               argLocs.push_back(
2846                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2847             }
2848             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2849                                               argLocs.data());
2850             MlirRegion region = mlirBlockGetParentRegion(self.get());
2851             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2852             return PyBlock(self.getParentOperation(), block);
2853           },
2854           "Creates and returns a new Block after this block "
2855           "(with given argument types).")
2856       .def(
2857           "__iter__",
2858           [](PyBlock &self) {
2859             self.checkValid();
2860             MlirOperation firstOperation =
2861                 mlirBlockGetFirstOperation(self.get());
2862             return PyOperationIterator(self.getParentOperation(),
2863                                        firstOperation);
2864           },
2865           "Iterates over operations in the block.")
2866       .def("__eq__",
2867            [](PyBlock &self, PyBlock &other) {
2868              return self.get().ptr == other.get().ptr;
2869            })
2870       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2871       .def(
2872           "__str__",
2873           [](PyBlock &self) {
2874             self.checkValid();
2875             PyPrintAccumulator printAccum;
2876             mlirBlockPrint(self.get(), printAccum.getCallback(),
2877                            printAccum.getUserData());
2878             return printAccum.join();
2879           },
2880           "Returns the assembly form of the block.")
2881       .def(
2882           "append",
2883           [](PyBlock &self, PyOperationBase &operation) {
2884             if (operation.getOperation().isAttached())
2885               operation.getOperation().detachFromParent();
2886 
2887             MlirOperation mlirOperation = operation.getOperation().get();
2888             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2889             operation.getOperation().setAttached(
2890                 self.getParentOperation().getObject());
2891           },
2892           py::arg("operation"),
2893           "Appends an operation to this block. If the operation is currently "
2894           "in another block, it will be moved.");
2895 
2896   //----------------------------------------------------------------------------
2897   // Mapping of PyInsertionPoint.
2898   //----------------------------------------------------------------------------
2899 
2900   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2901       .def(py::init<PyBlock &>(), py::arg("block"),
2902            "Inserts after the last operation but still inside the block.")
2903       .def("__enter__", &PyInsertionPoint::contextEnter)
2904       .def("__exit__", &PyInsertionPoint::contextExit)
2905       .def_property_readonly_static(
2906           "current",
2907           [](py::object & /*class*/) {
2908             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2909             if (!ip)
2910               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2911             return ip;
2912           },
2913           "Gets the InsertionPoint bound to the current thread or raises "
2914           "ValueError if none has been set")
2915       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2916            "Inserts before a referenced operation.")
2917       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2918                   py::arg("block"), "Inserts at the beginning of the block.")
2919       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2920                   py::arg("block"), "Inserts before the block terminator.")
2921       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2922            "Inserts an operation.")
2923       .def_property_readonly(
2924           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2925           "Returns the block that this InsertionPoint points to.");
2926 
2927   //----------------------------------------------------------------------------
2928   // Mapping of PyAttribute.
2929   //----------------------------------------------------------------------------
2930   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2931       // Delegate to the PyAttribute copy constructor, which will also lifetime
2932       // extend the backing context which owns the MlirAttribute.
2933       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2934            "Casts the passed attribute to the generic Attribute")
2935       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2936                              &PyAttribute::getCapsule)
2937       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2938       .def_static(
2939           "parse",
2940           [](std::string attrSpec, DefaultingPyMlirContext context) {
2941             MlirAttribute type = mlirAttributeParseGet(
2942                 context->get(), toMlirStringRef(attrSpec));
2943             // TODO: Rework error reporting once diagnostic engine is exposed
2944             // in C API.
2945             if (mlirAttributeIsNull(type)) {
2946               throw SetPyError(PyExc_ValueError,
2947                                Twine("Unable to parse attribute: '") +
2948                                    attrSpec + "'");
2949             }
2950             return PyAttribute(context->getRef(), type);
2951           },
2952           py::arg("asm"), py::arg("context") = py::none(),
2953           "Parses an attribute from an assembly form")
2954       .def_property_readonly(
2955           "context",
2956           [](PyAttribute &self) { return self.getContext().getObject(); },
2957           "Context that owns the Attribute")
2958       .def_property_readonly("type",
2959                              [](PyAttribute &self) {
2960                                return PyType(self.getContext()->getRef(),
2961                                              mlirAttributeGetType(self));
2962                              })
2963       .def(
2964           "get_named",
2965           [](PyAttribute &self, std::string name) {
2966             return PyNamedAttribute(self, std::move(name));
2967           },
2968           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2969       .def("__eq__",
2970            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2971       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2972       .def("__hash__",
2973            [](PyAttribute &self) {
2974              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2975            })
2976       .def(
2977           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2978           kDumpDocstring)
2979       .def(
2980           "__str__",
2981           [](PyAttribute &self) {
2982             PyPrintAccumulator printAccum;
2983             mlirAttributePrint(self, printAccum.getCallback(),
2984                                printAccum.getUserData());
2985             return printAccum.join();
2986           },
2987           "Returns the assembly form of the Attribute.")
2988       .def("__repr__", [](PyAttribute &self) {
2989         // Generally, assembly formats are not printed for __repr__ because
2990         // this can cause exceptionally long debug output and exceptions.
2991         // However, attribute values are generally considered useful and are
2992         // printed. This may need to be re-evaluated if debug dumps end up
2993         // being excessive.
2994         PyPrintAccumulator printAccum;
2995         printAccum.parts.append("Attribute(");
2996         mlirAttributePrint(self, printAccum.getCallback(),
2997                            printAccum.getUserData());
2998         printAccum.parts.append(")");
2999         return printAccum.join();
3000       });
3001 
3002   //----------------------------------------------------------------------------
3003   // Mapping of PyNamedAttribute
3004   //----------------------------------------------------------------------------
3005   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
3006       .def("__repr__",
3007            [](PyNamedAttribute &self) {
3008              PyPrintAccumulator printAccum;
3009              printAccum.parts.append("NamedAttribute(");
3010              printAccum.parts.append(
3011                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
3012                          mlirIdentifierStr(self.namedAttr.name).length));
3013              printAccum.parts.append("=");
3014              mlirAttributePrint(self.namedAttr.attribute,
3015                                 printAccum.getCallback(),
3016                                 printAccum.getUserData());
3017              printAccum.parts.append(")");
3018              return printAccum.join();
3019            })
3020       .def_property_readonly(
3021           "name",
3022           [](PyNamedAttribute &self) {
3023             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
3024                            mlirIdentifierStr(self.namedAttr.name).length);
3025           },
3026           "The name of the NamedAttribute binding")
3027       .def_property_readonly(
3028           "attr",
3029           [](PyNamedAttribute &self) {
3030             // TODO: When named attribute is removed/refactored, also remove
3031             // this constructor (it does an inefficient table lookup).
3032             auto contextRef = PyMlirContext::forContext(
3033                 mlirAttributeGetContext(self.namedAttr.attribute));
3034             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
3035           },
3036           py::keep_alive<0, 1>(),
3037           "The underlying generic attribute of the NamedAttribute binding");
3038 
3039   //----------------------------------------------------------------------------
3040   // Mapping of PyType.
3041   //----------------------------------------------------------------------------
3042   py::class_<PyType>(m, "Type", py::module_local())
3043       // Delegate to the PyType copy constructor, which will also lifetime
3044       // extend the backing context which owns the MlirType.
3045       .def(py::init<PyType &>(), py::arg("cast_from_type"),
3046            "Casts the passed type to the generic Type")
3047       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3048       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3049       .def_static(
3050           "parse",
3051           [](std::string typeSpec, DefaultingPyMlirContext context) {
3052             MlirType type =
3053                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3054             // TODO: Rework error reporting once diagnostic engine is exposed
3055             // in C API.
3056             if (mlirTypeIsNull(type)) {
3057               throw SetPyError(PyExc_ValueError,
3058                                Twine("Unable to parse type: '") + typeSpec +
3059                                    "'");
3060             }
3061             return PyType(context->getRef(), type);
3062           },
3063           py::arg("asm"), py::arg("context") = py::none(),
3064           kContextParseTypeDocstring)
3065       .def_property_readonly(
3066           "context", [](PyType &self) { return self.getContext().getObject(); },
3067           "Context that owns the Type")
3068       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3069       .def("__eq__", [](PyType &self, py::object &other) { return false; })
3070       .def("__hash__",
3071            [](PyType &self) {
3072              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3073            })
3074       .def(
3075           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3076       .def(
3077           "__str__",
3078           [](PyType &self) {
3079             PyPrintAccumulator printAccum;
3080             mlirTypePrint(self, printAccum.getCallback(),
3081                           printAccum.getUserData());
3082             return printAccum.join();
3083           },
3084           "Returns the assembly form of the type.")
3085       .def("__repr__", [](PyType &self) {
3086         // Generally, assembly formats are not printed for __repr__ because
3087         // this can cause exceptionally long debug output and exceptions.
3088         // However, types are an exception as they typically have compact
3089         // assembly forms and printing them is useful.
3090         PyPrintAccumulator printAccum;
3091         printAccum.parts.append("Type(");
3092         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
3093         printAccum.parts.append(")");
3094         return printAccum.join();
3095       });
3096 
3097   //----------------------------------------------------------------------------
3098   // Mapping of Value.
3099   //----------------------------------------------------------------------------
3100   py::class_<PyValue>(m, "Value", py::module_local())
3101       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3102       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3103       .def_property_readonly(
3104           "context",
3105           [](PyValue &self) { return self.getParentOperation()->getContext(); },
3106           "Context in which the value lives.")
3107       .def(
3108           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3109           kDumpDocstring)
3110       .def_property_readonly(
3111           "owner",
3112           [](PyValue &self) {
3113             assert(mlirOperationEqual(self.getParentOperation()->get(),
3114                                       mlirOpResultGetOwner(self.get())) &&
3115                    "expected the owner of the value in Python to match that in "
3116                    "the IR");
3117             return self.getParentOperation().getObject();
3118           })
3119       .def("__eq__",
3120            [](PyValue &self, PyValue &other) {
3121              return self.get().ptr == other.get().ptr;
3122            })
3123       .def("__eq__", [](PyValue &self, py::object other) { return false; })
3124       .def("__hash__",
3125            [](PyValue &self) {
3126              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3127            })
3128       .def(
3129           "__str__",
3130           [](PyValue &self) {
3131             PyPrintAccumulator printAccum;
3132             printAccum.parts.append("Value(");
3133             mlirValuePrint(self.get(), printAccum.getCallback(),
3134                            printAccum.getUserData());
3135             printAccum.parts.append(")");
3136             return printAccum.join();
3137           },
3138           kValueDunderStrDocstring)
3139       .def_property_readonly("type", [](PyValue &self) {
3140         return PyType(self.getParentOperation()->getContext(),
3141                       mlirValueGetType(self.get()));
3142       });
3143   PyBlockArgument::bind(m);
3144   PyOpResult::bind(m);
3145 
3146   //----------------------------------------------------------------------------
3147   // Mapping of SymbolTable.
3148   //----------------------------------------------------------------------------
3149   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3150       .def(py::init<PyOperationBase &>())
3151       .def("__getitem__", &PySymbolTable::dunderGetItem)
3152       .def("insert", &PySymbolTable::insert, py::arg("operation"))
3153       .def("erase", &PySymbolTable::erase, py::arg("operation"))
3154       .def("__delitem__", &PySymbolTable::dunderDel)
3155       .def("__contains__",
3156            [](PySymbolTable &table, const std::string &name) {
3157              return !mlirOperationIsNull(mlirSymbolTableLookup(
3158                  table, mlirStringRefCreate(name.data(), name.length())));
3159            })
3160       // Static helpers.
3161       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3162                   py::arg("symbol"), py::arg("name"))
3163       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3164                   py::arg("symbol"))
3165       .def_static("get_visibility", &PySymbolTable::getVisibility,
3166                   py::arg("symbol"))
3167       .def_static("set_visibility", &PySymbolTable::setVisibility,
3168                   py::arg("symbol"), py::arg("visibility"))
3169       .def_static("replace_all_symbol_uses",
3170                   &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3171                   py::arg("new_symbol"), py::arg("from_op"))
3172       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3173                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
3174                   py::arg("callback"));
3175 
  // Container bindings.
  // Each helper's `bind(m)` is defined elsewhere and registers one wrapper
  // class on module `m`. NOTE(review): kept in the existing order in case any
  // registration refers to a type registered by an earlier call.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  // NOTE(review): presumably exposes the global MLIR debug flag (this file
  // includes mlir-c/Debug.h) — confirm in PyGlobalDebugFlag::bind.
  PyGlobalDebugFlag::bind(m);
3190 }
3191