1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 //#include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 
23 #include <utility>
24 
25 namespace py = pybind11;
26 using namespace mlir;
27 using namespace mlir::python;
28 
29 using llvm::SmallVector;
30 using llvm::StringRef;
31 using llvm::Twine;
32 
33 //------------------------------------------------------------------------------
34 // Docstrings (trivial, non-duplicated docstrings are included inline).
35 //------------------------------------------------------------------------------
36 
// Docstring for parsing a type from its textual assembly form.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Docstrings for the Location factory helpers (call-site, file, fused and
// named locations).
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

// NOTE: this name is spelled "DocString" rather than "Docstring" like its
// siblings; it is left unchanged because code outside this view may use it.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

// Docstring for parsing a whole module from its textual assembly form.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring for the generic operation-creation entry point; lists every
// accepted argument and the defaulting rules for location/insertion point.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
82 
// Docstring for the operation print() method. The keyword names listed under
// "Args" are user-facing documentation and must be spelled exactly like the
// bound keyword arguments (snake_case), otherwise users copying them from
// help() get a TypeError.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";
108 
// Docstring for get_asm(); shared keyword arguments are documented on print().
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Docstring for an operation's default string conversion.
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Generic docstring shared by dump() methods.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring for appending a block to a region's block list
// (attached to BlockList.append in PyBlockList::bind).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Docstring for the string form of a Value (block argument or op result).
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
147 
148 //------------------------------------------------------------------------------
149 // Utilities.
150 //------------------------------------------------------------------------------
151 
152 /// Helper for creating an @classmethod.
153 template <class Func, typename... Args>
154 py::object classmethod(Func f, Args... args) {
155   py::object cf = py::cpp_function(f, args...);
156   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
157 }
158 
159 static py::object
160 createCustomDialectWrapper(const std::string &dialectNamespace,
161                            py::object dialectDescriptor) {
162   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
163   if (!dialectClass) {
164     // Use the base class.
165     return py::cast(PyDialect(std::move(dialectDescriptor)));
166   }
167 
168   // Create the custom implementation.
169   return (*dialectClass)(std::move(dialectDescriptor));
170 }
171 
172 static MlirStringRef toMlirStringRef(const std::string &s) {
173   return mlirStringRefCreate(s.data(), s.size());
174 }
175 
176 /// Wrapper for the global LLVM debugging flag.
177 struct PyGlobalDebugFlag {
178   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
179 
180   static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
181 
182   static void bind(py::module &m) {
183     // Debug flags.
184     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
185         .def_property_static("flag", &PyGlobalDebugFlag::get,
186                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
187   }
188 };
189 
190 //------------------------------------------------------------------------------
191 // Collections.
192 //------------------------------------------------------------------------------
193 
194 namespace {
195 
196 class PyRegionIterator {
197 public:
198   PyRegionIterator(PyOperationRef operation)
199       : operation(std::move(operation)) {}
200 
201   PyRegionIterator &dunderIter() { return *this; }
202 
203   PyRegion dunderNext() {
204     operation->checkValid();
205     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
206       throw py::stop_iteration();
207     }
208     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
209     return PyRegion(operation, region);
210   }
211 
212   static void bind(py::module &m) {
213     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
214         .def("__iter__", &PyRegionIterator::dunderIter)
215         .def("__next__", &PyRegionIterator::dunderNext);
216   }
217 
218 private:
219   PyOperationRef operation;
220   int nextIndex = 0;
221 };
222 
223 /// Regions of an op are fixed length and indexed numerically so are represented
224 /// with a sequence-like container.
225 class PyRegionList {
226 public:
227   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
228 
229   intptr_t dunderLen() {
230     operation->checkValid();
231     return mlirOperationGetNumRegions(operation->get());
232   }
233 
234   PyRegion dunderGetItem(intptr_t index) {
235     // dunderLen checks validity.
236     if (index < 0 || index >= dunderLen()) {
237       throw SetPyError(PyExc_IndexError,
238                        "attempt to access out of bounds region");
239     }
240     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
241     return PyRegion(operation, region);
242   }
243 
244   static void bind(py::module &m) {
245     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
246         .def("__len__", &PyRegionList::dunderLen)
247         .def("__getitem__", &PyRegionList::dunderGetItem);
248   }
249 
250 private:
251   PyOperationRef operation;
252 };
253 
254 class PyBlockIterator {
255 public:
256   PyBlockIterator(PyOperationRef operation, MlirBlock next)
257       : operation(std::move(operation)), next(next) {}
258 
259   PyBlockIterator &dunderIter() { return *this; }
260 
261   PyBlock dunderNext() {
262     operation->checkValid();
263     if (mlirBlockIsNull(next)) {
264       throw py::stop_iteration();
265     }
266 
267     PyBlock returnBlock(operation, next);
268     next = mlirBlockGetNextInRegion(next);
269     return returnBlock;
270   }
271 
272   static void bind(py::module &m) {
273     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
274         .def("__iter__", &PyBlockIterator::dunderIter)
275         .def("__next__", &PyBlockIterator::dunderNext);
276   }
277 
278 private:
279   PyOperationRef operation;
280   MlirBlock next;
281 };
282 
283 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
284 /// we present them as a more full-featured list-like container but optimize
285 /// it for forward iteration. Blocks are always owned by a region.
286 class PyBlockList {
287 public:
288   PyBlockList(PyOperationRef operation, MlirRegion region)
289       : operation(std::move(operation)), region(region) {}
290 
291   PyBlockIterator dunderIter() {
292     operation->checkValid();
293     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
294   }
295 
296   intptr_t dunderLen() {
297     operation->checkValid();
298     intptr_t count = 0;
299     MlirBlock block = mlirRegionGetFirstBlock(region);
300     while (!mlirBlockIsNull(block)) {
301       count += 1;
302       block = mlirBlockGetNextInRegion(block);
303     }
304     return count;
305   }
306 
307   PyBlock dunderGetItem(intptr_t index) {
308     operation->checkValid();
309     if (index < 0) {
310       throw SetPyError(PyExc_IndexError,
311                        "attempt to access out of bounds block");
312     }
313     MlirBlock block = mlirRegionGetFirstBlock(region);
314     while (!mlirBlockIsNull(block)) {
315       if (index == 0) {
316         return PyBlock(operation, block);
317       }
318       block = mlirBlockGetNextInRegion(block);
319       index -= 1;
320     }
321     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
322   }
323 
324   PyBlock appendBlock(const py::args &pyArgTypes) {
325     operation->checkValid();
326     llvm::SmallVector<MlirType, 4> argTypes;
327     llvm::SmallVector<MlirLocation, 4> argLocs;
328     argTypes.reserve(pyArgTypes.size());
329     argLocs.reserve(pyArgTypes.size());
330     for (auto &pyArg : pyArgTypes) {
331       argTypes.push_back(pyArg.cast<PyType &>());
332       // TODO: Pass in a proper location here.
333       argLocs.push_back(
334           mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
335     }
336 
337     MlirBlock block =
338         mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
339     mlirRegionAppendOwnedBlock(region, block);
340     return PyBlock(operation, block);
341   }
342 
343   static void bind(py::module &m) {
344     py::class_<PyBlockList>(m, "BlockList", py::module_local())
345         .def("__getitem__", &PyBlockList::dunderGetItem)
346         .def("__iter__", &PyBlockList::dunderIter)
347         .def("__len__", &PyBlockList::dunderLen)
348         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
349   }
350 
351 private:
352   PyOperationRef operation;
353   MlirRegion region;
354 };
355 
356 class PyOperationIterator {
357 public:
358   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
359       : parentOperation(std::move(parentOperation)), next(next) {}
360 
361   PyOperationIterator &dunderIter() { return *this; }
362 
363   py::object dunderNext() {
364     parentOperation->checkValid();
365     if (mlirOperationIsNull(next)) {
366       throw py::stop_iteration();
367     }
368 
369     PyOperationRef returnOperation =
370         PyOperation::forOperation(parentOperation->getContext(), next);
371     next = mlirOperationGetNextInBlock(next);
372     return returnOperation->createOpView();
373   }
374 
375   static void bind(py::module &m) {
376     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
377         .def("__iter__", &PyOperationIterator::dunderIter)
378         .def("__next__", &PyOperationIterator::dunderNext);
379   }
380 
381 private:
382   PyOperationRef parentOperation;
383   MlirOperation next;
384 };
385 
386 /// Operations are exposed by the C-API as a forward-only linked list. In
387 /// Python, we present them as a more full-featured list-like container but
388 /// optimize it for forward iteration. Iterable operations are always owned
389 /// by a block.
390 class PyOperationList {
391 public:
392   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
393       : parentOperation(std::move(parentOperation)), block(block) {}
394 
395   PyOperationIterator dunderIter() {
396     parentOperation->checkValid();
397     return PyOperationIterator(parentOperation,
398                                mlirBlockGetFirstOperation(block));
399   }
400 
401   intptr_t dunderLen() {
402     parentOperation->checkValid();
403     intptr_t count = 0;
404     MlirOperation childOp = mlirBlockGetFirstOperation(block);
405     while (!mlirOperationIsNull(childOp)) {
406       count += 1;
407       childOp = mlirOperationGetNextInBlock(childOp);
408     }
409     return count;
410   }
411 
412   py::object dunderGetItem(intptr_t index) {
413     parentOperation->checkValid();
414     if (index < 0) {
415       throw SetPyError(PyExc_IndexError,
416                        "attempt to access out of bounds operation");
417     }
418     MlirOperation childOp = mlirBlockGetFirstOperation(block);
419     while (!mlirOperationIsNull(childOp)) {
420       if (index == 0) {
421         return PyOperation::forOperation(parentOperation->getContext(), childOp)
422             ->createOpView();
423       }
424       childOp = mlirOperationGetNextInBlock(childOp);
425       index -= 1;
426     }
427     throw SetPyError(PyExc_IndexError,
428                      "attempt to access out of bounds operation");
429   }
430 
431   static void bind(py::module &m) {
432     py::class_<PyOperationList>(m, "OperationList", py::module_local())
433         .def("__getitem__", &PyOperationList::dunderGetItem)
434         .def("__iter__", &PyOperationList::dunderIter)
435         .def("__len__", &PyOperationList::dunderLen);
436   }
437 
438 private:
439   PyOperationRef parentOperation;
440   MlirBlock block;
441 };
442 
443 } // namespace
444 
445 //------------------------------------------------------------------------------
446 // PyMlirContext
447 //------------------------------------------------------------------------------
448 
449 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
450   py::gil_scoped_acquire acquire;
451   auto &liveContexts = getLiveContexts();
452   liveContexts[context.ptr] = this;
453 }
454 
455 PyMlirContext::~PyMlirContext() {
456   // Note that the only public way to construct an instance is via the
457   // forContext method, which always puts the associated handle into
458   // liveContexts.
459   py::gil_scoped_acquire acquire;
460   getLiveContexts().erase(context.ptr);
461   mlirContextDestroy(context);
462 }
463 
464 py::object PyMlirContext::getCapsule() {
465   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
466 }
467 
468 py::object PyMlirContext::createFromCapsule(py::object capsule) {
469   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
470   if (mlirContextIsNull(rawContext))
471     throw py::error_already_set();
472   return forContext(rawContext).releaseObject();
473 }
474 
475 PyMlirContext *PyMlirContext::createNewContextForInit() {
476   MlirContext context = mlirContextCreate();
477   return new PyMlirContext(context);
478 }
479 
480 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
481   py::gil_scoped_acquire acquire;
482   auto &liveContexts = getLiveContexts();
483   auto it = liveContexts.find(context.ptr);
484   if (it == liveContexts.end()) {
485     // Create.
486     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
487     py::object pyRef = py::cast(unownedContextWrapper);
488     assert(pyRef && "cast to py::object failed");
489     liveContexts[context.ptr] = unownedContextWrapper;
490     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
491   }
492   // Use existing.
493   py::object pyRef = py::cast(it->second);
494   return PyMlirContextRef(it->second, std::move(pyRef));
495 }
496 
497 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
498   static LiveContextMap liveContexts;
499   return liveContexts;
500 }
501 
502 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
503 
504 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
505 
506 size_t PyMlirContext::clearLiveOperations() {
507   for (auto &op : liveOperations)
508     op.second.second->setInvalid();
509   size_t numInvalidated = liveOperations.size();
510   liveOperations.clear();
511   return numInvalidated;
512 }
513 
514 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
515 
516 pybind11::object PyMlirContext::contextEnter() {
517   return PyThreadContextEntry::pushContext(*this);
518 }
519 
520 void PyMlirContext::contextExit(const pybind11::object &excType,
521                                 const pybind11::object &excVal,
522                                 const pybind11::object &excTb) {
523   PyThreadContextEntry::popContext(*this);
524 }
525 
526 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
527   // Note that ownership is transferred to the delete callback below by way of
528   // an explicit inc_ref (borrow).
529   PyDiagnosticHandler *pyHandler =
530       new PyDiagnosticHandler(get(), std::move(callback));
531   py::object pyHandlerObject =
532       py::cast(pyHandler, py::return_value_policy::take_ownership);
533   pyHandlerObject.inc_ref();
534 
535   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
536   // guaranteed to be known to pybind.
537   auto handlerCallback =
538       +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
539     PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
540     py::object pyDiagnosticObject =
541         py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
542 
543     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
544     bool result = false;
545     {
546       // Since this can be called from arbitrary C++ contexts, always get the
547       // gil.
548       py::gil_scoped_acquire gil;
549       try {
550         result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
551       } catch (std::exception &e) {
552         fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
553                 e.what());
554         pyHandler->hadError = true;
555       }
556     }
557 
558     pyDiagnostic->invalidate();
559     return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
560   };
561   auto deleteCallback = +[](void *userData) {
562     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
563     assert(pyHandler->registeredID && "handler is not registered");
564     pyHandler->registeredID.reset();
565 
566     // Decrement reference, balancing the inc_ref() above.
567     py::object pyHandlerObject =
568         py::cast(pyHandler, py::return_value_policy::reference);
569     pyHandlerObject.dec_ref();
570   };
571 
572   pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
573       get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
574   return pyHandlerObject;
575 }
576 
577 PyMlirContext &DefaultingPyMlirContext::resolve() {
578   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
579   if (!context) {
580     throw SetPyError(
581         PyExc_RuntimeError,
582         "An MLIR function requires a Context but none was provided in the call "
583         "or from the surrounding environment. Either pass to the function with "
584         "a 'context=' argument or establish a default using 'with Context():'");
585   }
586   return *context;
587 }
588 
589 //------------------------------------------------------------------------------
590 // PyThreadContextEntry management
591 //------------------------------------------------------------------------------
592 
593 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
594   static thread_local std::vector<PyThreadContextEntry> stack;
595   return stack;
596 }
597 
598 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
599   auto &stack = getStack();
600   if (stack.empty())
601     return nullptr;
602   return &stack.back();
603 }
604 
605 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
606                                 py::object insertionPoint,
607                                 py::object location) {
608   auto &stack = getStack();
609   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
610                      std::move(location));
611   // If the new stack has more than one entry and the context of the new top
612   // entry matches the previous, copy the insertionPoint and location from the
613   // previous entry if missing from the new top entry.
614   if (stack.size() > 1) {
615     auto &prev = *(stack.rbegin() + 1);
616     auto &current = stack.back();
617     if (current.context.is(prev.context)) {
618       // Default non-context objects from the previous entry.
619       if (!current.insertionPoint)
620         current.insertionPoint = prev.insertionPoint;
621       if (!current.location)
622         current.location = prev.location;
623     }
624   }
625 }
626 
627 PyMlirContext *PyThreadContextEntry::getContext() {
628   if (!context)
629     return nullptr;
630   return py::cast<PyMlirContext *>(context);
631 }
632 
633 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
634   if (!insertionPoint)
635     return nullptr;
636   return py::cast<PyInsertionPoint *>(insertionPoint);
637 }
638 
639 PyLocation *PyThreadContextEntry::getLocation() {
640   if (!location)
641     return nullptr;
642   return py::cast<PyLocation *>(location);
643 }
644 
645 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
646   auto *tos = getTopOfStack();
647   return tos ? tos->getContext() : nullptr;
648 }
649 
650 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
651   auto *tos = getTopOfStack();
652   return tos ? tos->getInsertionPoint() : nullptr;
653 }
654 
655 PyLocation *PyThreadContextEntry::getDefaultLocation() {
656   auto *tos = getTopOfStack();
657   return tos ? tos->getLocation() : nullptr;
658 }
659 
660 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
661   py::object contextObj = py::cast(context);
662   push(FrameKind::Context, /*context=*/contextObj,
663        /*insertionPoint=*/py::object(),
664        /*location=*/py::object());
665   return contextObj;
666 }
667 
668 void PyThreadContextEntry::popContext(PyMlirContext &context) {
669   auto &stack = getStack();
670   if (stack.empty())
671     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
672   auto &tos = stack.back();
673   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
674     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
675   stack.pop_back();
676 }
677 
678 py::object
679 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
680   py::object contextObj =
681       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
682   py::object insertionPointObj = py::cast(insertionPoint);
683   push(FrameKind::InsertionPoint,
684        /*context=*/contextObj,
685        /*insertionPoint=*/insertionPointObj,
686        /*location=*/py::object());
687   return insertionPointObj;
688 }
689 
690 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
691   auto &stack = getStack();
692   if (stack.empty())
693     throw SetPyError(PyExc_RuntimeError,
694                      "Unbalanced InsertionPoint enter/exit");
695   auto &tos = stack.back();
696   if (tos.frameKind != FrameKind::InsertionPoint &&
697       tos.getInsertionPoint() != &insertionPoint)
698     throw SetPyError(PyExc_RuntimeError,
699                      "Unbalanced InsertionPoint enter/exit");
700   stack.pop_back();
701 }
702 
703 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
704   py::object contextObj = location.getContext().getObject();
705   py::object locationObj = py::cast(location);
706   push(FrameKind::Location, /*context=*/contextObj,
707        /*insertionPoint=*/py::object(),
708        /*location=*/locationObj);
709   return locationObj;
710 }
711 
712 void PyThreadContextEntry::popLocation(PyLocation &location) {
713   auto &stack = getStack();
714   if (stack.empty())
715     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
716   auto &tos = stack.back();
717   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
718     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
719   stack.pop_back();
720 }
721 
722 //------------------------------------------------------------------------------
723 // PyDiagnostic*
724 //------------------------------------------------------------------------------
725 
726 void PyDiagnostic::invalidate() {
727   valid = false;
728   if (materializedNotes) {
729     for (auto &noteObject : *materializedNotes) {
730       PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
731       note->invalidate();
732     }
733   }
734 }
735 
736 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
737                                          py::object callback)
738     : context(context), callback(std::move(callback)) {}
739 
740 PyDiagnosticHandler::~PyDiagnosticHandler() = default;
741 
742 void PyDiagnosticHandler::detach() {
743   if (!registeredID)
744     return;
745   MlirDiagnosticHandlerID localID = *registeredID;
746   mlirContextDetachDiagnosticHandler(context, localID);
747   assert(!registeredID && "should have unregistered");
748   // Not strictly necessary but keeps stale pointers from being around to cause
749   // issues.
750   context = {nullptr};
751 }
752 
753 void PyDiagnostic::checkValid() {
754   if (!valid) {
755     throw std::invalid_argument(
756         "Diagnostic is invalid (used outside of callback)");
757   }
758 }
759 
760 MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
761   checkValid();
762   return mlirDiagnosticGetSeverity(diagnostic);
763 }
764 
765 PyLocation PyDiagnostic::getLocation() {
766   checkValid();
767   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
768   MlirContext context = mlirLocationGetContext(loc);
769   return PyLocation(PyMlirContext::forContext(context), loc);
770 }
771 
772 py::str PyDiagnostic::getMessage() {
773   checkValid();
774   py::object fileObject = py::module::import("io").attr("StringIO")();
775   PyFileAccumulator accum(fileObject, /*binary=*/false);
776   mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
777   return fileObject.attr("getvalue")();
778 }
779 
780 py::tuple PyDiagnostic::getNotes() {
781   checkValid();
782   if (materializedNotes)
783     return *materializedNotes;
784   intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
785   materializedNotes = py::tuple(numNotes);
786   for (intptr_t i = 0; i < numNotes; ++i) {
787     MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
788     py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag));
789     PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr());
790   }
791   return *materializedNotes;
792 }
793 
794 //------------------------------------------------------------------------------
795 // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry
796 //------------------------------------------------------------------------------
797 
798 MlirDialect PyDialects::getDialectForKey(const std::string &key,
799                                          bool attrError) {
800   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
801                                                     {key.data(), key.size()});
802   if (mlirDialectIsNull(dialect)) {
803     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
804                      Twine("Dialect '") + key + "' not found");
805   }
806   return dialect;
807 }
808 
809 py::object PyDialectRegistry::getCapsule() {
810   return py::reinterpret_steal<py::object>(
811       mlirPythonDialectRegistryToCapsule(*this));
812 }
813 
814 PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) {
815   MlirDialectRegistry rawRegistry =
816       mlirPythonCapsuleToDialectRegistry(capsule.ptr());
817   if (mlirDialectRegistryIsNull(rawRegistry))
818     throw py::error_already_set();
819   return PyDialectRegistry(rawRegistry);
820 }
821 
822 //------------------------------------------------------------------------------
823 // PyLocation
824 //------------------------------------------------------------------------------
825 
826 py::object PyLocation::getCapsule() {
827   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
828 }
829 
830 PyLocation PyLocation::createFromCapsule(py::object capsule) {
831   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
832   if (mlirLocationIsNull(rawLoc))
833     throw py::error_already_set();
834   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
835                     rawLoc);
836 }
837 
838 py::object PyLocation::contextEnter() {
839   return PyThreadContextEntry::pushLocation(*this);
840 }
841 
842 void PyLocation::contextExit(const pybind11::object &excType,
843                              const pybind11::object &excVal,
844                              const pybind11::object &excTb) {
845   PyThreadContextEntry::popLocation(*this);
846 }
847 
848 PyLocation &DefaultingPyLocation::resolve() {
849   auto *location = PyThreadContextEntry::getDefaultLocation();
850   if (!location) {
851     throw SetPyError(
852         PyExc_RuntimeError,
853         "An MLIR function requires a Location but none was provided in the "
854         "call or from the surrounding environment. Either pass to the function "
855         "with a 'loc=' argument or establish a default using 'with loc:'");
856   }
857   return *location;
858 }
859 
860 //------------------------------------------------------------------------------
861 // PyModule
862 //------------------------------------------------------------------------------
863 
864 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
865     : BaseContextObject(std::move(contextRef)), module(module) {}
866 
867 PyModule::~PyModule() {
868   py::gil_scoped_acquire acquire;
869   auto &liveModules = getContext()->liveModules;
870   assert(liveModules.count(module.ptr) == 1 &&
871          "destroying module not in live map");
872   liveModules.erase(module.ptr);
873   mlirModuleDestroy(module);
874 }
875 
876 PyModuleRef PyModule::forModule(MlirModule module) {
877   MlirContext context = mlirModuleGetContext(module);
878   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
879 
880   py::gil_scoped_acquire acquire;
881   auto &liveModules = contextRef->liveModules;
882   auto it = liveModules.find(module.ptr);
883   if (it == liveModules.end()) {
884     // Create.
885     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
886     // Note that the default return value policy on cast is automatic_reference,
887     // which does not take ownership (delete will not be called).
888     // Just be explicit.
889     py::object pyRef =
890         py::cast(unownedModule, py::return_value_policy::take_ownership);
891     unownedModule->handle = pyRef;
892     liveModules[module.ptr] =
893         std::make_pair(unownedModule->handle, unownedModule);
894     return PyModuleRef(unownedModule, std::move(pyRef));
895   }
896   // Use existing.
897   PyModule *existing = it->second.second;
898   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
899   return PyModuleRef(existing, std::move(pyRef));
900 }
901 
902 py::object PyModule::createFromCapsule(py::object capsule) {
903   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
904   if (mlirModuleIsNull(rawModule))
905     throw py::error_already_set();
906   return forModule(rawModule).releaseObject();
907 }
908 
909 py::object PyModule::getCapsule() {
910   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
911 }
912 
913 //------------------------------------------------------------------------------
914 // PyOperation
915 //------------------------------------------------------------------------------
916 
917 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
918     : BaseContextObject(std::move(contextRef)), operation(operation) {}
919 
920 PyOperation::~PyOperation() {
921   // If the operation has already been invalidated there is nothing to do.
922   if (!valid)
923     return;
924   auto &liveOperations = getContext()->liveOperations;
925   assert(liveOperations.count(operation.ptr) == 1 &&
926          "destroying operation not in live map");
927   liveOperations.erase(operation.ptr);
928   if (!isAttached()) {
929     mlirOperationDestroy(operation);
930   }
931 }
932 
933 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
934                                            MlirOperation operation,
935                                            py::object parentKeepAlive) {
936   auto &liveOperations = contextRef->liveOperations;
937   // Create.
938   PyOperation *unownedOperation =
939       new PyOperation(std::move(contextRef), operation);
940   // Note that the default return value policy on cast is automatic_reference,
941   // which does not take ownership (delete will not be called).
942   // Just be explicit.
943   py::object pyRef =
944       py::cast(unownedOperation, py::return_value_policy::take_ownership);
945   unownedOperation->handle = pyRef;
946   if (parentKeepAlive) {
947     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
948   }
949   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
950   return PyOperationRef(unownedOperation, std::move(pyRef));
951 }
952 
953 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
954                                          MlirOperation operation,
955                                          py::object parentKeepAlive) {
956   auto &liveOperations = contextRef->liveOperations;
957   auto it = liveOperations.find(operation.ptr);
958   if (it == liveOperations.end()) {
959     // Create.
960     return createInstance(std::move(contextRef), operation,
961                           std::move(parentKeepAlive));
962   }
963   // Use existing.
964   PyOperation *existing = it->second.second;
965   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
966   return PyOperationRef(existing, std::move(pyRef));
967 }
968 
969 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
970                                            MlirOperation operation,
971                                            py::object parentKeepAlive) {
972   auto &liveOperations = contextRef->liveOperations;
973   assert(liveOperations.count(operation.ptr) == 0 &&
974          "cannot create detached operation that already exists");
975   (void)liveOperations;
976 
977   PyOperationRef created = createInstance(std::move(contextRef), operation,
978                                           std::move(parentKeepAlive));
979   created->attached = false;
980   return created;
981 }
982 
983 void PyOperation::checkValid() const {
984   if (!valid) {
985     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
986   }
987 }
988 
989 void PyOperationBase::print(py::object fileObject, bool binary,
990                             llvm::Optional<int64_t> largeElementsLimit,
991                             bool enableDebugInfo, bool prettyDebugInfo,
992                             bool printGenericOpForm, bool useLocalScope,
993                             bool assumeVerified) {
994   PyOperation &operation = getOperation();
995   operation.checkValid();
996   if (fileObject.is_none())
997     fileObject = py::module::import("sys").attr("stdout");
998 
999   if (!assumeVerified && !printGenericOpForm &&
1000       !mlirOperationVerify(operation)) {
1001     std::string message("// Verification failed, printing generic form\n");
1002     if (binary) {
1003       fileObject.attr("write")(py::bytes(message));
1004     } else {
1005       fileObject.attr("write")(py::str(message));
1006     }
1007     printGenericOpForm = true;
1008   }
1009 
1010   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1011   if (largeElementsLimit)
1012     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1013   if (enableDebugInfo)
1014     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
1015   if (printGenericOpForm)
1016     mlirOpPrintingFlagsPrintGenericOpForm(flags);
1017   if (useLocalScope)
1018     mlirOpPrintingFlagsUseLocalScope(flags);
1019 
1020   PyFileAccumulator accum(fileObject, binary);
1021   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1022                               accum.getUserData());
1023   mlirOpPrintingFlagsDestroy(flags);
1024 }
1025 
1026 py::object PyOperationBase::getAsm(bool binary,
1027                                    llvm::Optional<int64_t> largeElementsLimit,
1028                                    bool enableDebugInfo, bool prettyDebugInfo,
1029                                    bool printGenericOpForm, bool useLocalScope,
1030                                    bool assumeVerified) {
1031   py::object fileObject;
1032   if (binary) {
1033     fileObject = py::module::import("io").attr("BytesIO")();
1034   } else {
1035     fileObject = py::module::import("io").attr("StringIO")();
1036   }
1037   print(fileObject, /*binary=*/binary,
1038         /*largeElementsLimit=*/largeElementsLimit,
1039         /*enableDebugInfo=*/enableDebugInfo,
1040         /*prettyDebugInfo=*/prettyDebugInfo,
1041         /*printGenericOpForm=*/printGenericOpForm,
1042         /*useLocalScope=*/useLocalScope,
1043         /*assumeVerified=*/assumeVerified);
1044 
1045   return fileObject.attr("getvalue")();
1046 }
1047 
1048 void PyOperationBase::moveAfter(PyOperationBase &other) {
1049   PyOperation &operation = getOperation();
1050   PyOperation &otherOp = other.getOperation();
1051   operation.checkValid();
1052   otherOp.checkValid();
1053   mlirOperationMoveAfter(operation, otherOp);
1054   operation.parentKeepAlive = otherOp.parentKeepAlive;
1055 }
1056 
1057 void PyOperationBase::moveBefore(PyOperationBase &other) {
1058   PyOperation &operation = getOperation();
1059   PyOperation &otherOp = other.getOperation();
1060   operation.checkValid();
1061   otherOp.checkValid();
1062   mlirOperationMoveBefore(operation, otherOp);
1063   operation.parentKeepAlive = otherOp.parentKeepAlive;
1064 }
1065 
1066 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
1067   checkValid();
1068   if (!isAttached())
1069     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
1070   MlirOperation operation = mlirOperationGetParentOperation(get());
1071   if (mlirOperationIsNull(operation))
1072     return {};
1073   return PyOperation::forOperation(getContext(), operation);
1074 }
1075 
1076 PyBlock PyOperation::getBlock() {
1077   checkValid();
1078   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
1079   MlirBlock block = mlirOperationGetBlock(get());
1080   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1081   assert(parentOperation && "Operation has no parent");
1082   return PyBlock{std::move(*parentOperation), block};
1083 }
1084 
1085 py::object PyOperation::getCapsule() {
1086   checkValid();
1087   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1088 }
1089 
1090 py::object PyOperation::createFromCapsule(py::object capsule) {
1091   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1092   if (mlirOperationIsNull(rawOperation))
1093     throw py::error_already_set();
1094   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1095   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1096       .releaseObject();
1097 }
1098 
1099 static void maybeInsertOperation(PyOperationRef &op,
1100                                  const py::object &maybeIp) {
1101   // InsertPoint active?
1102   if (!maybeIp.is(py::cast(false))) {
1103     PyInsertionPoint *ip;
1104     if (maybeIp.is_none()) {
1105       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1106     } else {
1107       ip = py::cast<PyInsertionPoint *>(maybeIp);
1108     }
1109     if (ip)
1110       ip->insert(*op.get());
1111   }
1112 }
1113 
// Creates a new operation named `name` and returns it wrapped in its OpView.
//
// All Python-level inputs (operands, result types, attributes, successors)
// are validated and converted to C API values first. Only then is an
// MlirOperationState assembled and the operation created with `regions`
// fresh, empty regions. The new operation starts detached (createDetached)
// and is then passed to maybeInsertOperation together with `maybeIp`.
//
// Raises ValueError (via SetPyError) for a negative region count or for None
// operands, result types or successors. Raises py::cast_error for attribute
// keys that are not strings or attribute values that are not Attributes.
py::object PyOperation::create(
    const std::string &name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, const py::object &maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Attribute names are owned here (std::string) because the named
  // attributes built below only reference their bytes.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result types originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception appears to be thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successors originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Each region is created empty and handed to the state as owned.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created->createOpView();
}
1235 
1236 py::object PyOperation::clone(const py::object &maybeIp) {
1237   MlirOperation clonedOperation = mlirOperationClone(operation);
1238   PyOperationRef cloned =
1239       PyOperation::createDetached(getContext(), clonedOperation);
1240   maybeInsertOperation(cloned, maybeIp);
1241 
1242   return cloned->createOpView();
1243 }
1244 
1245 py::object PyOperation::createOpView() {
1246   checkValid();
1247   MlirIdentifier ident = mlirOperationGetName(get());
1248   MlirStringRef identStr = mlirIdentifierStr(ident);
1249   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1250       StringRef(identStr.data, identStr.length));
1251   if (opViewClass)
1252     return (*opViewClass)(getRef().getObject());
1253   return py::cast(PyOpView(getRef().getObject()));
1254 }
1255 
1256 void PyOperation::erase() {
1257   checkValid();
1258   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1259   // Python reference to a child operation is live. All children should also
1260   // have their `valid` bit set to false.
1261   auto &liveOperations = getContext()->liveOperations;
1262   if (liveOperations.count(operation.ptr))
1263     liveOperations.erase(operation.ptr);
1264   mlirOperationDestroy(operation);
1265   valid = false;
1266 }
1267 
1268 //------------------------------------------------------------------------------
1269 // PyOpView
1270 //------------------------------------------------------------------------------
1271 
1272 py::object PyOpView::buildGeneric(
1273     const py::object &cls, py::list resultTypeList, py::list operandList,
1274     llvm::Optional<py::dict> attributes,
1275     llvm::Optional<std::vector<PyBlock *>> successors,
1276     llvm::Optional<int> regions, DefaultingPyLocation location,
1277     const py::object &maybeIp) {
1278   PyMlirContextRef context = location->getContext();
1279   // Class level operation construction metadata.
1280   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1281   // Operand and result segment specs are either none, which does no
1282   // variadic unpacking, or a list of ints with segment sizes, where each
1283   // element is either a positive number (typically 1 for a scalar) or -1 to
1284   // indicate that it is derived from the length of the same-indexed operand
1285   // or result (implying that it is a list at that position).
1286   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1287   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1288 
1289   std::vector<uint32_t> operandSegmentLengths;
1290   std::vector<uint32_t> resultSegmentLengths;
1291 
1292   // Validate/determine region count.
1293   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1294   int opMinRegionCount = std::get<0>(opRegionSpec);
1295   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1296   if (!regions) {
1297     regions = opMinRegionCount;
1298   }
1299   if (*regions < opMinRegionCount) {
1300     throw py::value_error(
1301         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1302          llvm::Twine(opMinRegionCount) +
1303          " regions but was built with regions=" + llvm::Twine(*regions))
1304             .str());
1305   }
1306   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1307     throw py::value_error(
1308         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1309          llvm::Twine(opMinRegionCount) +
1310          " regions but was built with regions=" + llvm::Twine(*regions))
1311             .str());
1312   }
1313 
1314   // Unpack results.
1315   std::vector<PyType *> resultTypes;
1316   resultTypes.reserve(resultTypeList.size());
1317   if (resultSegmentSpecObj.is_none()) {
1318     // Non-variadic result unpacking.
1319     for (const auto &it : llvm::enumerate(resultTypeList)) {
1320       try {
1321         resultTypes.push_back(py::cast<PyType *>(it.value()));
1322         if (!resultTypes.back())
1323           throw py::cast_error();
1324       } catch (py::cast_error &err) {
1325         throw py::value_error((llvm::Twine("Result ") +
1326                                llvm::Twine(it.index()) + " of operation \"" +
1327                                name + "\" must be a Type (" + err.what() + ")")
1328                                   .str());
1329       }
1330     }
1331   } else {
1332     // Sized result unpacking.
1333     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1334     if (resultSegmentSpec.size() != resultTypeList.size()) {
1335       throw py::value_error((llvm::Twine("Operation \"") + name +
1336                              "\" requires " +
1337                              llvm::Twine(resultSegmentSpec.size()) +
1338                              " result segments but was provided " +
1339                              llvm::Twine(resultTypeList.size()))
1340                                 .str());
1341     }
1342     resultSegmentLengths.reserve(resultTypeList.size());
1343     for (const auto &it :
1344          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1345       int segmentSpec = std::get<1>(it.value());
1346       if (segmentSpec == 1 || segmentSpec == 0) {
1347         // Unpack unary element.
1348         try {
1349           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1350           if (resultType) {
1351             resultTypes.push_back(resultType);
1352             resultSegmentLengths.push_back(1);
1353           } else if (segmentSpec == 0) {
1354             // Allowed to be optional.
1355             resultSegmentLengths.push_back(0);
1356           } else {
1357             throw py::cast_error("was None and result is not optional");
1358           }
1359         } catch (py::cast_error &err) {
1360           throw py::value_error((llvm::Twine("Result ") +
1361                                  llvm::Twine(it.index()) + " of operation \"" +
1362                                  name + "\" must be a Type (" + err.what() +
1363                                  ")")
1364                                     .str());
1365         }
1366       } else if (segmentSpec == -1) {
1367         // Unpack sequence by appending.
1368         try {
1369           if (std::get<0>(it.value()).is_none()) {
1370             // Treat it as an empty list.
1371             resultSegmentLengths.push_back(0);
1372           } else {
1373             // Unpack the list.
1374             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1375             for (py::object segmentItem : segment) {
1376               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1377               if (!resultTypes.back()) {
1378                 throw py::cast_error("contained a None item");
1379               }
1380             }
1381             resultSegmentLengths.push_back(segment.size());
1382           }
1383         } catch (std::exception &err) {
1384           // NOTE: Sloppy to be using a catch-all here, but there are at least
1385           // three different unrelated exceptions that can be thrown in the
1386           // above "casts". Just keep the scope above small and catch them all.
1387           throw py::value_error((llvm::Twine("Result ") +
1388                                  llvm::Twine(it.index()) + " of operation \"" +
1389                                  name + "\" must be a Sequence of Types (" +
1390                                  err.what() + ")")
1391                                     .str());
1392         }
1393       } else {
1394         throw py::value_error("Unexpected segment spec");
1395       }
1396     }
1397   }
1398 
1399   // Unpack operands.
1400   std::vector<PyValue *> operands;
1401   operands.reserve(operands.size());
1402   if (operandSegmentSpecObj.is_none()) {
1403     // Non-sized operand unpacking.
1404     for (const auto &it : llvm::enumerate(operandList)) {
1405       try {
1406         operands.push_back(py::cast<PyValue *>(it.value()));
1407         if (!operands.back())
1408           throw py::cast_error();
1409       } catch (py::cast_error &err) {
1410         throw py::value_error((llvm::Twine("Operand ") +
1411                                llvm::Twine(it.index()) + " of operation \"" +
1412                                name + "\" must be a Value (" + err.what() + ")")
1413                                   .str());
1414       }
1415     }
1416   } else {
1417     // Sized operand unpacking.
1418     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1419     if (operandSegmentSpec.size() != operandList.size()) {
1420       throw py::value_error((llvm::Twine("Operation \"") + name +
1421                              "\" requires " +
1422                              llvm::Twine(operandSegmentSpec.size()) +
1423                              "operand segments but was provided " +
1424                              llvm::Twine(operandList.size()))
1425                                 .str());
1426     }
1427     operandSegmentLengths.reserve(operandList.size());
1428     for (const auto &it :
1429          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1430       int segmentSpec = std::get<1>(it.value());
1431       if (segmentSpec == 1 || segmentSpec == 0) {
1432         // Unpack unary element.
1433         try {
1434           auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1435           if (operandValue) {
1436             operands.push_back(operandValue);
1437             operandSegmentLengths.push_back(1);
1438           } else if (segmentSpec == 0) {
1439             // Allowed to be optional.
1440             operandSegmentLengths.push_back(0);
1441           } else {
1442             throw py::cast_error("was None and operand is not optional");
1443           }
1444         } catch (py::cast_error &err) {
1445           throw py::value_error((llvm::Twine("Operand ") +
1446                                  llvm::Twine(it.index()) + " of operation \"" +
1447                                  name + "\" must be a Value (" + err.what() +
1448                                  ")")
1449                                     .str());
1450         }
1451       } else if (segmentSpec == -1) {
1452         // Unpack sequence by appending.
1453         try {
1454           if (std::get<0>(it.value()).is_none()) {
1455             // Treat it as an empty list.
1456             operandSegmentLengths.push_back(0);
1457           } else {
1458             // Unpack the list.
1459             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1460             for (py::object segmentItem : segment) {
1461               operands.push_back(py::cast<PyValue *>(segmentItem));
1462               if (!operands.back()) {
1463                 throw py::cast_error("contained a None item");
1464               }
1465             }
1466             operandSegmentLengths.push_back(segment.size());
1467           }
1468         } catch (std::exception &err) {
1469           // NOTE: Sloppy to be using a catch-all here, but there are at least
1470           // three different unrelated exceptions that can be thrown in the
1471           // above "casts". Just keep the scope above small and catch them all.
1472           throw py::value_error((llvm::Twine("Operand ") +
1473                                  llvm::Twine(it.index()) + " of operation \"" +
1474                                  name + "\" must be a Sequence of Values (" +
1475                                  err.what() + ")")
1476                                     .str());
1477         }
1478       } else {
1479         throw py::value_error("Unexpected segment spec");
1480       }
1481     }
1482   }
1483 
1484   // Merge operand/result segment lengths into attributes if needed.
1485   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1486     // Dup.
1487     if (attributes) {
1488       attributes = py::dict(*attributes);
1489     } else {
1490       attributes = py::dict();
1491     }
1492     if (attributes->contains("result_segment_sizes") ||
1493         attributes->contains("operand_segment_sizes")) {
1494       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1495                             "'operand_segment_sizes' attribute is unsupported. "
1496                             "Use Operation.create for such low-level access.");
1497     }
1498 
1499     // Add result_segment_sizes attribute.
1500     if (!resultSegmentLengths.empty()) {
1501       int64_t size = resultSegmentLengths.size();
1502       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1503           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1504           resultSegmentLengths.size(), resultSegmentLengths.data());
1505       (*attributes)["result_segment_sizes"] =
1506           PyAttribute(context, segmentLengthAttr);
1507     }
1508 
1509     // Add operand_segment_sizes attribute.
1510     if (!operandSegmentLengths.empty()) {
1511       int64_t size = operandSegmentLengths.size();
1512       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1513           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1514           operandSegmentLengths.size(), operandSegmentLengths.data());
1515       (*attributes)["operand_segment_sizes"] =
1516           PyAttribute(context, segmentLengthAttr);
1517     }
1518   }
1519 
1520   // Delegate to create.
1521   return PyOperation::create(name,
1522                              /*results=*/std::move(resultTypes),
1523                              /*operands=*/std::move(operands),
1524                              /*attributes=*/std::move(attributes),
1525                              /*successors=*/std::move(successors),
1526                              /*regions=*/*regions, location, maybeIp);
1527 }
1528 
// Wraps an existing operation. `operationObject` may be any Python object
// castable to PyOperationBase (presumably an Operation or another OpView).
PyOpView::PyOpView(const py::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      // NOTE(review): this reads the `operation` member initialized just
      // above, so it relies on `operation` being declared before
      // `operationObject` in the class (confirm in IRModule.h). Inside the
      // first initializer, `operationObject` names the constructor argument,
      // which shadows the member of the same name.
      operationObject(operation.getRef().getObject()) {}
1534 
1535 py::object PyOpView::createRawSubclass(const py::object &userClass) {
1536   // This is... a little gross. The typical pattern is to have a pure python
1537   // class that extends OpView like:
1538   //   class AddFOp(_cext.ir.OpView):
1539   //     def __init__(self, loc, lhs, rhs):
1540   //       operation = loc.context.create_operation(
1541   //           "addf", lhs, rhs, results=[lhs.type])
1542   //       super().__init__(operation)
1543   //
1544   // I.e. The goal of the user facing type is to provide a nice constructor
1545   // that has complete freedom for the op under construction. This is at odds
1546   // with our other desire to sometimes create this object by just passing an
1547   // operation (to initialize the base class). We could do *arg and **kwargs
1548   // munging to try to make it work, but instead, we synthesize a new class
1549   // on the fly which extends this user class (AddFOp in this example) and
1550   // *give it* the base class's __init__ method, thus bypassing the
1551   // intermediate subclass's __init__ method entirely. While slightly,
1552   // underhanded, this is safe/legal because the type hierarchy has not changed
1553   // (we just added a new leaf) and we aren't mucking around with __new__.
1554   // Typically, this new class will be stored on the original as "_Raw" and will
1555   // be used for casts and other things that need a variant of the class that
1556   // is initialized purely from an operation.
1557   py::object parentMetaclass =
1558       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1559   py::dict attributes;
1560   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1561   // now.
1562   //   auto opViewType = py::type::of<PyOpView>();
1563   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1564   attributes["__init__"] = opViewType.attr("__init__");
1565   py::str origName = userClass.attr("__name__");
1566   py::str newName = py::str("_") + origName;
1567   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1568 }
1569 
1570 //------------------------------------------------------------------------------
1571 // PyInsertionPoint.
1572 //------------------------------------------------------------------------------
1573 
1574 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1575 
1576 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1577     : refOperation(beforeOperationBase.getOperation().getRef()),
1578       block((*refOperation)->getBlock()) {}
1579 
1580 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1581   PyOperation &operation = operationBase.getOperation();
1582   if (operation.isAttached())
1583     throw SetPyError(PyExc_ValueError,
1584                      "Attempt to insert operation that is already attached");
1585   block.getParentOperation()->checkValid();
1586   MlirOperation beforeOp = {nullptr};
1587   if (refOperation) {
1588     // Insert before operation.
1589     (*refOperation)->checkValid();
1590     beforeOp = (*refOperation)->get();
1591   } else {
1592     // Insert at end (before null) is only valid if the block does not
1593     // already end in a known terminator (violating this will cause assertion
1594     // failures later).
1595     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1596       throw py::index_error("Cannot insert operation at the end of a block "
1597                             "that already has a terminator. Did you mean to "
1598                             "use 'InsertionPoint.at_block_terminator(block)' "
1599                             "versus 'InsertionPoint(block)'?");
1600     }
1601   }
1602   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1603   operation.setAttached();
1604 }
1605 
1606 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1607   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1608   if (mlirOperationIsNull(firstOp)) {
1609     // Just insert at end.
1610     return PyInsertionPoint(block);
1611   }
1612 
1613   // Insert before first op.
1614   PyOperationRef firstOpRef = PyOperation::forOperation(
1615       block.getParentOperation()->getContext(), firstOp);
1616   return PyInsertionPoint{block, std::move(firstOpRef)};
1617 }
1618 
1619 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1620   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1621   if (mlirOperationIsNull(terminator))
1622     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1623   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1624       block.getParentOperation()->getContext(), terminator);
1625   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1626 }
1627 
1628 py::object PyInsertionPoint::contextEnter() {
1629   return PyThreadContextEntry::pushInsertionPoint(*this);
1630 }
1631 
1632 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1633                                    const pybind11::object &excVal,
1634                                    const pybind11::object &excTb) {
1635   PyThreadContextEntry::popInsertionPoint(*this);
1636 }
1637 
1638 //------------------------------------------------------------------------------
1639 // PyAttribute.
1640 //------------------------------------------------------------------------------
1641 
1642 bool PyAttribute::operator==(const PyAttribute &other) {
1643   return mlirAttributeEqual(attr, other.attr);
1644 }
1645 
1646 py::object PyAttribute::getCapsule() {
1647   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1648 }
1649 
1650 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1651   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1652   if (mlirAttributeIsNull(rawAttr))
1653     throw py::error_already_set();
1654   return PyAttribute(
1655       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1656 }
1657 
1658 //------------------------------------------------------------------------------
1659 // PyNamedAttribute.
1660 //------------------------------------------------------------------------------
1661 
1662 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1663     : ownedName(new std::string(std::move(ownedName))) {
1664   namedAttr = mlirNamedAttributeGet(
1665       mlirIdentifierGet(mlirAttributeGetContext(attr),
1666                         toMlirStringRef(*this->ownedName)),
1667       attr);
1668 }
1669 
1670 //------------------------------------------------------------------------------
1671 // PyType.
1672 //------------------------------------------------------------------------------
1673 
1674 bool PyType::operator==(const PyType &other) {
1675   return mlirTypeEqual(type, other.type);
1676 }
1677 
1678 py::object PyType::getCapsule() {
1679   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1680 }
1681 
1682 PyType PyType::createFromCapsule(py::object capsule) {
1683   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1684   if (mlirTypeIsNull(rawType))
1685     throw py::error_already_set();
1686   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1687                 rawType);
1688 }
1689 
1690 //------------------------------------------------------------------------------
// PyValue and subclasses.
1692 //------------------------------------------------------------------------------
1693 
1694 pybind11::object PyValue::getCapsule() {
1695   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1696 }
1697 
1698 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1699   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1700   if (mlirValueIsNull(value))
1701     throw py::error_already_set();
1702   MlirOperation owner;
1703   if (mlirValueIsAOpResult(value))
1704     owner = mlirOpResultGetOwner(value);
1705   if (mlirValueIsABlockArgument(value))
1706     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1707   if (mlirOperationIsNull(owner))
1708     throw py::error_already_set();
1709   MlirContext ctx = mlirOperationGetContext(owner);
1710   PyOperationRef ownerRef =
1711       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1712   return PyValue(ownerRef, value);
1713 }
1714 
1715 //------------------------------------------------------------------------------
1716 // PySymbolTable.
1717 //------------------------------------------------------------------------------
1718 
1719 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1720     : operation(operation.getOperation().getRef()) {
1721   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1722   if (mlirSymbolTableIsNull(symbolTable)) {
1723     throw py::cast_error("Operation is not a Symbol Table.");
1724   }
1725 }
1726 
1727 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1728   operation->checkValid();
1729   MlirOperation symbol = mlirSymbolTableLookup(
1730       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1731   if (mlirOperationIsNull(symbol))
1732     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1733 
1734   return PyOperation::forOperation(operation->getContext(), symbol,
1735                                    operation.getObject())
1736       ->createOpView();
1737 }
1738 
1739 void PySymbolTable::erase(PyOperationBase &symbol) {
1740   operation->checkValid();
1741   symbol.getOperation().checkValid();
1742   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1743   // The operation is also erased, so we must invalidate it. There may be Python
1744   // references to this operation so we don't want to delete it from the list of
1745   // live operations here.
1746   symbol.getOperation().valid = false;
1747 }
1748 
1749 void PySymbolTable::dunderDel(const std::string &name) {
1750   py::object operation = dunderGetItem(name);
1751   erase(py::cast<PyOperationBase &>(operation));
1752 }
1753 
1754 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1755   operation->checkValid();
1756   symbol.getOperation().checkValid();
1757   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1758       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1759   if (mlirAttributeIsNull(symbolAttr))
1760     throw py::value_error("Expected operation to have a symbol name.");
1761   return PyAttribute(
1762       symbol.getOperation().getContext(),
1763       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1764 }
1765 
1766 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1767   // Op must already be a symbol.
1768   PyOperation &operation = symbol.getOperation();
1769   operation.checkValid();
1770   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1771   MlirAttribute existingNameAttr =
1772       mlirOperationGetAttributeByName(operation.get(), attrName);
1773   if (mlirAttributeIsNull(existingNameAttr))
1774     throw py::value_error("Expected operation to have a symbol name.");
1775   return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1776 }
1777 
1778 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1779                                   const std::string &name) {
1780   // Op must already be a symbol.
1781   PyOperation &operation = symbol.getOperation();
1782   operation.checkValid();
1783   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1784   MlirAttribute existingNameAttr =
1785       mlirOperationGetAttributeByName(operation.get(), attrName);
1786   if (mlirAttributeIsNull(existingNameAttr))
1787     throw py::value_error("Expected operation to have a symbol name.");
1788   MlirAttribute newNameAttr =
1789       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1790   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1791 }
1792 
1793 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1794   PyOperation &operation = symbol.getOperation();
1795   operation.checkValid();
1796   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1797   MlirAttribute existingVisAttr =
1798       mlirOperationGetAttributeByName(operation.get(), attrName);
1799   if (mlirAttributeIsNull(existingVisAttr))
1800     throw py::value_error("Expected operation to have a symbol visibility.");
1801   return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1802 }
1803 
1804 void PySymbolTable::setVisibility(PyOperationBase &symbol,
1805                                   const std::string &visibility) {
1806   if (visibility != "public" && visibility != "private" &&
1807       visibility != "nested")
1808     throw py::value_error(
1809         "Expected visibility to be 'public', 'private' or 'nested'");
1810   PyOperation &operation = symbol.getOperation();
1811   operation.checkValid();
1812   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1813   MlirAttribute existingVisAttr =
1814       mlirOperationGetAttributeByName(operation.get(), attrName);
1815   if (mlirAttributeIsNull(existingVisAttr))
1816     throw py::value_error("Expected operation to have a symbol visibility.");
1817   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1818                                                toMlirStringRef(visibility));
1819   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1820 }
1821 
1822 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1823                                          const std::string &newSymbol,
1824                                          PyOperationBase &from) {
1825   PyOperation &fromOperation = from.getOperation();
1826   fromOperation.checkValid();
1827   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1828           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1829           from.getOperation())))
1830 
1831     throw py::value_error("Symbol rename failed");
1832 }
1833 
1834 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1835                                      bool allSymUsesVisible,
1836                                      py::object callback) {
1837   PyOperation &fromOperation = from.getOperation();
1838   fromOperation.checkValid();
1839   struct UserData {
1840     PyMlirContextRef context;
1841     py::object callback;
1842     bool gotException;
1843     std::string exceptionWhat;
1844     py::object exceptionType;
1845   };
1846   UserData userData{
1847       fromOperation.getContext(), std::move(callback), false, {}, {}};
1848   mlirSymbolTableWalkSymbolTables(
1849       fromOperation.get(), allSymUsesVisible,
1850       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1851         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1852         auto pyFoundOp =
1853             PyOperation::forOperation(calleeUserData->context, foundOp);
1854         if (calleeUserData->gotException)
1855           return;
1856         try {
1857           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1858         } catch (py::error_already_set &e) {
1859           calleeUserData->gotException = true;
1860           calleeUserData->exceptionWhat = e.what();
1861           calleeUserData->exceptionType = e.type();
1862         }
1863       },
1864       static_cast<void *>(&userData));
1865   if (userData.gotException) {
1866     std::string message("Exception raised in callback: ");
1867     message.append(userData.exceptionWhat);
1868     throw std::runtime_error(message);
1869   }
1870 }
1871 
1872 namespace {
1873 /// CRTP base class for Python MLIR values that subclass Value and should be
1874 /// castable from it. The value hierarchy is one level deep and is not supposed
1875 /// to accommodate other levels unless core MLIR changes.
1876 template <typename DerivedTy>
1877 class PyConcreteValue : public PyValue {
1878 public:
1879   // Derived classes must define statics for:
1880   //   IsAFunctionTy isaFunction
1881   //   const char *pyClassName
1882   // and redefine bindDerived.
1883   using ClassTy = py::class_<DerivedTy, PyValue>;
1884   using IsAFunctionTy = bool (*)(MlirValue);
1885 
1886   PyConcreteValue() = default;
1887   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1888       : PyValue(operationRef, value) {}
1889   PyConcreteValue(PyValue &orig)
1890       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1891 
1892   /// Attempts to cast the original value to the derived type and throws on
1893   /// type mismatches.
1894   static MlirValue castFrom(PyValue &orig) {
1895     if (!DerivedTy::isaFunction(orig.get())) {
1896       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1897       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1898                                              DerivedTy::pyClassName +
1899                                              " (from " + origRepr + ")");
1900     }
1901     return orig.get();
1902   }
1903 
1904   /// Binds the Python module objects to functions of this class.
1905   static void bind(py::module &m) {
1906     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1907     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1908     cls.def_static(
1909         "isinstance",
1910         [](PyValue &otherValue) -> bool {
1911           return DerivedTy::isaFunction(otherValue);
1912         },
1913         py::arg("other_value"));
1914     DerivedTy::bindDerived(cls);
1915   }
1916 
1917   /// Implemented by derived classes to add methods to the Python subclass.
1918   static void bindDerived(ClassTy &m) {}
1919 };
1920 
1921 /// Python wrapper for MlirBlockArgument.
1922 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1923 public:
1924   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1925   static constexpr const char *pyClassName = "BlockArgument";
1926   using PyConcreteValue::PyConcreteValue;
1927 
1928   static void bindDerived(ClassTy &c) {
1929     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1930       return PyBlock(self.getParentOperation(),
1931                      mlirBlockArgumentGetOwner(self.get()));
1932     });
1933     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1934       return mlirBlockArgumentGetArgNumber(self.get());
1935     });
1936     c.def(
1937         "set_type",
1938         [](PyBlockArgument &self, PyType type) {
1939           return mlirBlockArgumentSetType(self.get(), type);
1940         },
1941         py::arg("type"));
1942   }
1943 };
1944 
1945 /// Python wrapper for MlirOpResult.
1946 class PyOpResult : public PyConcreteValue<PyOpResult> {
1947 public:
1948   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1949   static constexpr const char *pyClassName = "OpResult";
1950   using PyConcreteValue::PyConcreteValue;
1951 
1952   static void bindDerived(ClassTy &c) {
1953     c.def_property_readonly("owner", [](PyOpResult &self) {
1954       assert(
1955           mlirOperationEqual(self.getParentOperation()->get(),
1956                              mlirOpResultGetOwner(self.get())) &&
1957           "expected the owner of the value in Python to match that in the IR");
1958       return self.getParentOperation().getObject();
1959     });
1960     c.def_property_readonly("result_number", [](PyOpResult &self) {
1961       return mlirOpResultGetResultNumber(self.get());
1962     });
1963   }
1964 };
1965 
1966 /// Returns the list of types of the values held by container.
1967 template <typename Container>
1968 static std::vector<PyType> getValueTypes(Container &container,
1969                                          PyMlirContextRef &context) {
1970   std::vector<PyType> result;
1971   result.reserve(container.size());
1972   for (int i = 0, e = container.size(); i < e; ++i) {
1973     result.push_back(
1974         PyType(context, mlirValueGetType(container.getElement(i).get())));
1975   }
1976   return result;
1977 }
1978 
1979 /// A list of block arguments. Internally, these are stored as consecutive
1980 /// elements, random access is cheap. The argument list is associated with the
1981 /// operation that contains the block (detached blocks are not allowed in
1982 /// Python bindings) and extends its lifetime.
1983 class PyBlockArgumentList
1984     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1985 public:
1986   static constexpr const char *pyClassName = "BlockArgumentList";
1987 
1988   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1989                       intptr_t startIndex = 0, intptr_t length = -1,
1990                       intptr_t step = 1)
1991       : Sliceable(startIndex,
1992                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1993                   step),
1994         operation(std::move(operation)), block(block) {}
1995 
1996   static void bindDerived(ClassTy &c) {
1997     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1998       return getValueTypes(self, self.operation->getContext());
1999     });
2000   }
2001 
2002 private:
2003   /// Give the parent CRTP class access to hook implementations below.
2004   friend class Sliceable<PyBlockArgumentList, PyBlockArgument>;
2005 
2006   /// Returns the number of arguments in the list.
2007   intptr_t getRawNumElements() {
2008     operation->checkValid();
2009     return mlirBlockGetNumArguments(block);
2010   }
2011 
2012   /// Returns `pos`-the element in the list.
2013   PyBlockArgument getRawElement(intptr_t pos) {
2014     MlirValue argument = mlirBlockGetArgument(block, pos);
2015     return PyBlockArgument(operation, argument);
2016   }
2017 
2018   /// Returns a sublist of this list.
2019   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
2020                             intptr_t step) {
2021     return PyBlockArgumentList(operation, block, startIndex, length, step);
2022   }
2023 
2024   PyOperationRef operation;
2025   MlirBlock block;
2026 };
2027 
2028 /// A list of operation operands. Internally, these are stored as consecutive
2029 /// elements, random access is cheap. The result list is associated with the
2030 /// operation whose results these are, and extends the lifetime of this
2031 /// operation.
2032 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2033 public:
2034   static constexpr const char *pyClassName = "OpOperandList";
2035 
2036   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2037                   intptr_t length = -1, intptr_t step = 1)
2038       : Sliceable(startIndex,
2039                   length == -1 ? mlirOperationGetNumOperands(operation->get())
2040                                : length,
2041                   step),
2042         operation(operation) {}
2043 
2044   void dunderSetItem(intptr_t index, PyValue value) {
2045     index = wrapIndex(index);
2046     mlirOperationSetOperand(operation->get(), index, value.get());
2047   }
2048 
2049   static void bindDerived(ClassTy &c) {
2050     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2051   }
2052 
2053 private:
2054   /// Give the parent CRTP class access to hook implementations below.
2055   friend class Sliceable<PyOpOperandList, PyValue>;
2056 
2057   intptr_t getRawNumElements() {
2058     operation->checkValid();
2059     return mlirOperationGetNumOperands(operation->get());
2060   }
2061 
2062   PyValue getRawElement(intptr_t pos) {
2063     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2064     MlirOperation owner;
2065     if (mlirValueIsAOpResult(operand))
2066       owner = mlirOpResultGetOwner(operand);
2067     else if (mlirValueIsABlockArgument(operand))
2068       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2069     else
2070       assert(false && "Value must be an block arg or op result.");
2071     PyOperationRef pyOwner =
2072         PyOperation::forOperation(operation->getContext(), owner);
2073     return PyValue(pyOwner, operand);
2074   }
2075 
2076   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2077     return PyOpOperandList(operation, startIndex, length, step);
2078   }
2079 
2080   PyOperationRef operation;
2081 };
2082 
2083 /// A list of operation results. Internally, these are stored as consecutive
2084 /// elements, random access is cheap. The result list is associated with the
2085 /// operation whose results these are, and extends the lifetime of this
2086 /// operation.
2087 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2088 public:
2089   static constexpr const char *pyClassName = "OpResultList";
2090 
2091   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2092                  intptr_t length = -1, intptr_t step = 1)
2093       : Sliceable(startIndex,
2094                   length == -1 ? mlirOperationGetNumResults(operation->get())
2095                                : length,
2096                   step),
2097         operation(operation) {}
2098 
2099   static void bindDerived(ClassTy &c) {
2100     c.def_property_readonly("types", [](PyOpResultList &self) {
2101       return getValueTypes(self, self.operation->getContext());
2102     });
2103   }
2104 
2105 private:
2106   /// Give the parent CRTP class access to hook implementations below.
2107   friend class Sliceable<PyOpResultList, PyOpResult>;
2108 
2109   intptr_t getRawNumElements() {
2110     operation->checkValid();
2111     return mlirOperationGetNumResults(operation->get());
2112   }
2113 
2114   PyOpResult getRawElement(intptr_t index) {
2115     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2116     return PyOpResult(value);
2117   }
2118 
2119   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2120     return PyOpResultList(operation, startIndex, length, step);
2121   }
2122 
2123   PyOperationRef operation;
2124 };
2125 
2126 /// A list of operation attributes. Can be indexed by name, producing
2127 /// attributes, or by index, producing named attributes.
2128 class PyOpAttributeMap {
2129 public:
2130   PyOpAttributeMap(PyOperationRef operation)
2131       : operation(std::move(operation)) {}
2132 
2133   PyAttribute dunderGetItemNamed(const std::string &name) {
2134     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2135                                                          toMlirStringRef(name));
2136     if (mlirAttributeIsNull(attr)) {
2137       throw SetPyError(PyExc_KeyError,
2138                        "attempt to access a non-existent attribute");
2139     }
2140     return PyAttribute(operation->getContext(), attr);
2141   }
2142 
2143   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2144     if (index < 0 || index >= dunderLen()) {
2145       throw SetPyError(PyExc_IndexError,
2146                        "attempt to access out of bounds attribute");
2147     }
2148     MlirNamedAttribute namedAttr =
2149         mlirOperationGetAttribute(operation->get(), index);
2150     return PyNamedAttribute(
2151         namedAttr.attribute,
2152         std::string(mlirIdentifierStr(namedAttr.name).data,
2153                     mlirIdentifierStr(namedAttr.name).length));
2154   }
2155 
2156   void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2157     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2158                                     attr);
2159   }
2160 
2161   void dunderDelItem(const std::string &name) {
2162     int removed = mlirOperationRemoveAttributeByName(operation->get(),
2163                                                      toMlirStringRef(name));
2164     if (!removed)
2165       throw SetPyError(PyExc_KeyError,
2166                        "attempt to delete a non-existent attribute");
2167   }
2168 
2169   intptr_t dunderLen() {
2170     return mlirOperationGetNumAttributes(operation->get());
2171   }
2172 
2173   bool dunderContains(const std::string &name) {
2174     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2175         operation->get(), toMlirStringRef(name)));
2176   }
2177 
2178   static void bind(py::module &m) {
2179     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2180         .def("__contains__", &PyOpAttributeMap::dunderContains)
2181         .def("__len__", &PyOpAttributeMap::dunderLen)
2182         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2183         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2184         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2185         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2186   }
2187 
2188 private:
2189   PyOperationRef operation;
2190 };
2191 
2192 } // namespace
2193 
2194 //------------------------------------------------------------------------------
2195 // Populates the core exports of the 'ir' submodule.
2196 //------------------------------------------------------------------------------
2197 
2198 void mlir::python::populateIRCore(py::module &m) {
2199   //----------------------------------------------------------------------------
2200   // Enums.
2201   //----------------------------------------------------------------------------
2202   py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2203       .value("ERROR", MlirDiagnosticError)
2204       .value("WARNING", MlirDiagnosticWarning)
2205       .value("NOTE", MlirDiagnosticNote)
2206       .value("REMARK", MlirDiagnosticRemark);
2207 
2208   //----------------------------------------------------------------------------
2209   // Mapping of Diagnostics.
2210   //----------------------------------------------------------------------------
2211   py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2212       .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2213       .def_property_readonly("location", &PyDiagnostic::getLocation)
2214       .def_property_readonly("message", &PyDiagnostic::getMessage)
2215       .def_property_readonly("notes", &PyDiagnostic::getNotes)
2216       .def("__str__", [](PyDiagnostic &self) -> py::str {
2217         if (!self.isValid())
2218           return "<Invalid Diagnostic>";
2219         return self.getMessage();
2220       });
2221 
2222   py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2223       .def("detach", &PyDiagnosticHandler::detach)
2224       .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2225       .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2226       .def("__enter__", &PyDiagnosticHandler::contextEnter)
2227       .def("__exit__", &PyDiagnosticHandler::contextExit);
2228 
2229   //----------------------------------------------------------------------------
2230   // Mapping of MlirContext.
2231   // Note that this is exported as _BaseContext. The containing, Python level
2232   // __init__.py will subclass it with site-specific functionality and set a
2233   // "Context" attribute on this module.
2234   //----------------------------------------------------------------------------
2235   py::class_<PyMlirContext>(m, "_BaseContext", py::module_local())
2236       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2237       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2238       .def("_get_context_again",
2239            [](PyMlirContext &self) {
2240              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2241              return ref.releaseObject();
2242            })
2243       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2244       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2245       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2246       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2247                              &PyMlirContext::getCapsule)
2248       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2249       .def("__enter__", &PyMlirContext::contextEnter)
2250       .def("__exit__", &PyMlirContext::contextExit)
2251       .def_property_readonly_static(
2252           "current",
2253           [](py::object & /*class*/) {
2254             auto *context = PyThreadContextEntry::getDefaultContext();
2255             if (!context)
2256               throw SetPyError(PyExc_ValueError, "No current Context");
2257             return context;
2258           },
2259           "Gets the Context bound to the current thread or raises ValueError")
2260       .def_property_readonly(
2261           "dialects",
2262           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2263           "Gets a container for accessing dialects by name")
2264       .def_property_readonly(
2265           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2266           "Alias for 'dialect'")
2267       .def(
2268           "get_dialect_descriptor",
2269           [=](PyMlirContext &self, std::string &name) {
2270             MlirDialect dialect = mlirContextGetOrLoadDialect(
2271                 self.get(), {name.data(), name.size()});
2272             if (mlirDialectIsNull(dialect)) {
2273               throw SetPyError(PyExc_ValueError,
2274                                Twine("Dialect '") + name + "' not found");
2275             }
2276             return PyDialectDescriptor(self.getRef(), dialect);
2277           },
2278           py::arg("dialect_name"),
2279           "Gets or loads a dialect by name, returning its descriptor object")
2280       .def_property(
2281           "allow_unregistered_dialects",
2282           [](PyMlirContext &self) -> bool {
2283             return mlirContextGetAllowUnregisteredDialects(self.get());
2284           },
2285           [](PyMlirContext &self, bool value) {
2286             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2287           })
2288       .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2289            py::arg("callback"),
2290            "Attaches a diagnostic handler that will receive callbacks")
2291       .def(
2292           "enable_multithreading",
2293           [](PyMlirContext &self, bool enable) {
2294             mlirContextEnableMultithreading(self.get(), enable);
2295           },
2296           py::arg("enable"))
2297       .def(
2298           "is_registered_operation",
2299           [](PyMlirContext &self, std::string &name) {
2300             return mlirContextIsRegisteredOperation(
2301                 self.get(), MlirStringRef{name.data(), name.size()});
2302           },
2303           py::arg("operation_name"))
2304       .def(
2305           "append_dialect_registry",
2306           [](PyMlirContext &self, PyDialectRegistry &registry) {
2307             mlirContextAppendDialectRegistry(self.get(), registry);
2308           },
2309           py::arg("registry"))
2310       .def("load_all_available_dialects", [](PyMlirContext &self) {
2311         mlirContextLoadAllAvailableDialects(self.get());
2312       });
2313 
2314   //----------------------------------------------------------------------------
2315   // Mapping of PyDialectDescriptor
2316   //----------------------------------------------------------------------------
2317   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2318       .def_property_readonly("namespace",
2319                              [](PyDialectDescriptor &self) {
2320                                MlirStringRef ns =
2321                                    mlirDialectGetNamespace(self.get());
2322                                return py::str(ns.data, ns.length);
2323                              })
2324       .def("__repr__", [](PyDialectDescriptor &self) {
2325         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2326         std::string repr("<DialectDescriptor ");
2327         repr.append(ns.data, ns.length);
2328         repr.append(">");
2329         return repr;
2330       });
2331 
2332   //----------------------------------------------------------------------------
2333   // Mapping of PyDialects
2334   //----------------------------------------------------------------------------
2335   py::class_<PyDialects>(m, "Dialects", py::module_local())
2336       .def("__getitem__",
2337            [=](PyDialects &self, std::string keyName) {
2338              MlirDialect dialect =
2339                  self.getDialectForKey(keyName, /*attrError=*/false);
2340              py::object descriptor =
2341                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2342              return createCustomDialectWrapper(keyName, std::move(descriptor));
2343            })
2344       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2345         MlirDialect dialect =
2346             self.getDialectForKey(attrName, /*attrError=*/true);
2347         py::object descriptor =
2348             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2349         return createCustomDialectWrapper(attrName, std::move(descriptor));
2350       });
2351 
2352   //----------------------------------------------------------------------------
2353   // Mapping of PyDialect
2354   //----------------------------------------------------------------------------
2355   py::class_<PyDialect>(m, "Dialect", py::module_local())
2356       .def(py::init<py::object>(), py::arg("descriptor"))
2357       .def_property_readonly(
2358           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2359       .def("__repr__", [](py::object self) {
2360         auto clazz = self.attr("__class__");
2361         return py::str("<Dialect ") +
2362                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2363                clazz.attr("__module__") + py::str(".") +
2364                clazz.attr("__name__") + py::str(")>");
2365       });
2366 
2367   //----------------------------------------------------------------------------
2368   // Mapping of PyDialectRegistry
2369   //----------------------------------------------------------------------------
2370   py::class_<PyDialectRegistry>(m, "DialectRegistry", py::module_local())
2371       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2372                              &PyDialectRegistry::getCapsule)
2373       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule)
2374       .def(py::init<>());
2375 
2376   //----------------------------------------------------------------------------
2377   // Mapping of Location
2378   //----------------------------------------------------------------------------
2379   py::class_<PyLocation>(m, "Location", py::module_local())
2380       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2381       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2382       .def("__enter__", &PyLocation::contextEnter)
2383       .def("__exit__", &PyLocation::contextExit)
2384       .def("__eq__",
2385            [](PyLocation &self, PyLocation &other) -> bool {
2386              return mlirLocationEqual(self, other);
2387            })
2388       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2389       .def_property_readonly_static(
2390           "current",
2391           [](py::object & /*class*/) {
2392             auto *loc = PyThreadContextEntry::getDefaultLocation();
2393             if (!loc)
2394               throw SetPyError(PyExc_ValueError, "No current Location");
2395             return loc;
2396           },
2397           "Gets the Location bound to the current thread or raises ValueError")
2398       .def_static(
2399           "unknown",
2400           [](DefaultingPyMlirContext context) {
2401             return PyLocation(context->getRef(),
2402                               mlirLocationUnknownGet(context->get()));
2403           },
2404           py::arg("context") = py::none(),
2405           "Gets a Location representing an unknown location")
2406       .def_static(
2407           "callsite",
2408           [](PyLocation callee, const std::vector<PyLocation> &frames,
2409              DefaultingPyMlirContext context) {
2410             if (frames.empty())
2411               throw py::value_error("No caller frames provided");
2412             MlirLocation caller = frames.back().get();
2413             for (const PyLocation &frame :
2414                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2415               caller = mlirLocationCallSiteGet(frame.get(), caller);
2416             return PyLocation(context->getRef(),
2417                               mlirLocationCallSiteGet(callee.get(), caller));
2418           },
2419           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2420           kContextGetCallSiteLocationDocstring)
2421       .def_static(
2422           "file",
2423           [](std::string filename, int line, int col,
2424              DefaultingPyMlirContext context) {
2425             return PyLocation(
2426                 context->getRef(),
2427                 mlirLocationFileLineColGet(
2428                     context->get(), toMlirStringRef(filename), line, col));
2429           },
2430           py::arg("filename"), py::arg("line"), py::arg("col"),
2431           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2432       .def_static(
2433           "fused",
2434           [](const std::vector<PyLocation> &pyLocations,
2435              llvm::Optional<PyAttribute> metadata,
2436              DefaultingPyMlirContext context) {
2437             llvm::SmallVector<MlirLocation, 4> locations;
2438             locations.reserve(pyLocations.size());
2439             for (auto &pyLocation : pyLocations)
2440               locations.push_back(pyLocation.get());
2441             MlirLocation location = mlirLocationFusedGet(
2442                 context->get(), locations.size(), locations.data(),
2443                 metadata ? metadata->get() : MlirAttribute{0});
2444             return PyLocation(context->getRef(), location);
2445           },
2446           py::arg("locations"), py::arg("metadata") = py::none(),
2447           py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2448       .def_static(
2449           "name",
2450           [](std::string name, llvm::Optional<PyLocation> childLoc,
2451              DefaultingPyMlirContext context) {
2452             return PyLocation(
2453                 context->getRef(),
2454                 mlirLocationNameGet(
2455                     context->get(), toMlirStringRef(name),
2456                     childLoc ? childLoc->get()
2457                              : mlirLocationUnknownGet(context->get())));
2458           },
2459           py::arg("name"), py::arg("childLoc") = py::none(),
2460           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2461       .def_property_readonly(
2462           "context",
2463           [](PyLocation &self) { return self.getContext().getObject(); },
2464           "Context that owns the Location")
2465       .def(
2466           "emit_error",
2467           [](PyLocation &self, std::string message) {
2468             mlirEmitError(self, message.c_str());
2469           },
2470           py::arg("message"), "Emits an error at this location")
2471       .def("__repr__", [](PyLocation &self) {
2472         PyPrintAccumulator printAccum;
2473         mlirLocationPrint(self, printAccum.getCallback(),
2474                           printAccum.getUserData());
2475         return printAccum.join();
2476       });
2477 
2478   //----------------------------------------------------------------------------
2479   // Mapping of Module
2480   //----------------------------------------------------------------------------
2481   py::class_<PyModule>(m, "Module", py::module_local())
2482       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2483       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2484       .def_static(
2485           "parse",
2486           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2487             MlirModule module = mlirModuleCreateParse(
2488                 context->get(), toMlirStringRef(moduleAsm));
2489             // TODO: Rework error reporting once diagnostic engine is exposed
2490             // in C API.
2491             if (mlirModuleIsNull(module)) {
2492               throw SetPyError(
2493                   PyExc_ValueError,
2494                   "Unable to parse module assembly (see diagnostics)");
2495             }
2496             return PyModule::forModule(module).releaseObject();
2497           },
2498           py::arg("asm"), py::arg("context") = py::none(),
2499           kModuleParseDocstring)
2500       .def_static(
2501           "create",
2502           [](DefaultingPyLocation loc) {
2503             MlirModule module = mlirModuleCreateEmpty(loc);
2504             return PyModule::forModule(module).releaseObject();
2505           },
2506           py::arg("loc") = py::none(), "Creates an empty module")
2507       .def_property_readonly(
2508           "context",
2509           [](PyModule &self) { return self.getContext().getObject(); },
2510           "Context that created the Module")
2511       .def_property_readonly(
2512           "operation",
2513           [](PyModule &self) {
2514             return PyOperation::forOperation(self.getContext(),
2515                                              mlirModuleGetOperation(self.get()),
2516                                              self.getRef().releaseObject())
2517                 .releaseObject();
2518           },
2519           "Accesses the module as an operation")
2520       .def_property_readonly(
2521           "body",
2522           [](PyModule &self) {
2523             PyOperationRef moduleOp = PyOperation::forOperation(
2524                 self.getContext(), mlirModuleGetOperation(self.get()),
2525                 self.getRef().releaseObject());
2526             PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2527             return returnBlock;
2528           },
2529           "Return the block for this module")
2530       .def(
2531           "dump",
2532           [](PyModule &self) {
2533             mlirOperationDump(mlirModuleGetOperation(self.get()));
2534           },
2535           kDumpDocstring)
2536       .def(
2537           "__str__",
2538           [](py::object self) {
2539             // Defer to the operation's __str__.
2540             return self.attr("operation").attr("__str__")();
2541           },
2542           kOperationStrDunderDocstring);
2543 
2544   //----------------------------------------------------------------------------
2545   // Mapping of Operation.
2546   //----------------------------------------------------------------------------
2547   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2548       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2549                              [](PyOperationBase &self) {
2550                                return self.getOperation().getCapsule();
2551                              })
2552       .def("__eq__",
2553            [](PyOperationBase &self, PyOperationBase &other) {
2554              return &self.getOperation() == &other.getOperation();
2555            })
2556       .def("__eq__",
2557            [](PyOperationBase &self, py::object other) { return false; })
2558       .def("__hash__",
2559            [](PyOperationBase &self) {
2560              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2561            })
2562       .def_property_readonly("attributes",
2563                              [](PyOperationBase &self) {
2564                                return PyOpAttributeMap(
2565                                    self.getOperation().getRef());
2566                              })
2567       .def_property_readonly("operands",
2568                              [](PyOperationBase &self) {
2569                                return PyOpOperandList(
2570                                    self.getOperation().getRef());
2571                              })
2572       .def_property_readonly("regions",
2573                              [](PyOperationBase &self) {
2574                                return PyRegionList(
2575                                    self.getOperation().getRef());
2576                              })
2577       .def_property_readonly(
2578           "results",
2579           [](PyOperationBase &self) {
2580             return PyOpResultList(self.getOperation().getRef());
2581           },
2582           "Returns the list of Operation results.")
2583       .def_property_readonly(
2584           "result",
2585           [](PyOperationBase &self) {
2586             auto &operation = self.getOperation();
2587             auto numResults = mlirOperationGetNumResults(operation);
2588             if (numResults != 1) {
2589               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2590               throw SetPyError(
2591                   PyExc_ValueError,
2592                   Twine("Cannot call .result on operation ") +
2593                       StringRef(name.data, name.length) + " which has " +
2594                       Twine(numResults) +
2595                       " results (it is only valid for operations with a "
2596                       "single result)");
2597             }
2598             return PyOpResult(operation.getRef(),
2599                               mlirOperationGetResult(operation, 0));
2600           },
2601           "Shortcut to get an op result if it has only one (throws an error "
2602           "otherwise).")
2603       .def_property_readonly(
2604           "location",
2605           [](PyOperationBase &self) {
2606             PyOperation &operation = self.getOperation();
2607             return PyLocation(operation.getContext(),
2608                               mlirOperationGetLocation(operation.get()));
2609           },
2610           "Returns the source location the operation was defined or derived "
2611           "from.")
2612       .def(
2613           "__str__",
2614           [](PyOperationBase &self) {
2615             return self.getAsm(/*binary=*/false,
2616                                /*largeElementsLimit=*/llvm::None,
2617                                /*enableDebugInfo=*/false,
2618                                /*prettyDebugInfo=*/false,
2619                                /*printGenericOpForm=*/false,
2620                                /*useLocalScope=*/false,
2621                                /*assumeVerified=*/false);
2622           },
2623           "Returns the assembly form of the operation.")
2624       .def("print", &PyOperationBase::print,
2625            // Careful: Lots of arguments must match up with print method.
2626            py::arg("file") = py::none(), py::arg("binary") = false,
2627            py::arg("large_elements_limit") = py::none(),
2628            py::arg("enable_debug_info") = false,
2629            py::arg("pretty_debug_info") = false,
2630            py::arg("print_generic_op_form") = false,
2631            py::arg("use_local_scope") = false,
2632            py::arg("assume_verified") = false, kOperationPrintDocstring)
2633       .def("get_asm", &PyOperationBase::getAsm,
2634            // Careful: Lots of arguments must match up with get_asm method.
2635            py::arg("binary") = false,
2636            py::arg("large_elements_limit") = py::none(),
2637            py::arg("enable_debug_info") = false,
2638            py::arg("pretty_debug_info") = false,
2639            py::arg("print_generic_op_form") = false,
2640            py::arg("use_local_scope") = false,
2641            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
2642       .def(
2643           "verify",
2644           [](PyOperationBase &self) {
2645             return mlirOperationVerify(self.getOperation());
2646           },
2647           "Verify the operation and return true if it passes, false if it "
2648           "fails.")
2649       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2650            "Puts self immediately after the other operation in its parent "
2651            "block.")
2652       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2653            "Puts self immediately before the other operation in its parent "
2654            "block.")
2655       .def(
2656           "detach_from_parent",
2657           [](PyOperationBase &self) {
2658             PyOperation &operation = self.getOperation();
2659             operation.checkValid();
2660             if (!operation.isAttached())
2661               throw py::value_error("Detached operation has no parent.");
2662 
2663             operation.detachFromParent();
2664             return operation.createOpView();
2665           },
2666           "Detaches the operation from its parent block.");
2667 
2668   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2669       .def_static("create", &PyOperation::create, py::arg("name"),
2670                   py::arg("results") = py::none(),
2671                   py::arg("operands") = py::none(),
2672                   py::arg("attributes") = py::none(),
2673                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2674                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2675                   kOperationCreateDocstring)
2676       .def_property_readonly("parent",
2677                              [](PyOperation &self) -> py::object {
2678                                auto parent = self.getParentOperation();
2679                                if (parent)
2680                                  return parent->getObject();
2681                                return py::none();
2682                              })
2683       .def("erase", &PyOperation::erase)
2684       .def("clone", &PyOperation::clone, py::arg("ip") = py::none())
2685       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2686                              &PyOperation::getCapsule)
2687       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2688       .def_property_readonly("name",
2689                              [](PyOperation &self) {
2690                                self.checkValid();
2691                                MlirOperation operation = self.get();
2692                                MlirStringRef name = mlirIdentifierStr(
2693                                    mlirOperationGetName(operation));
2694                                return py::str(name.data, name.length);
2695                              })
2696       .def_property_readonly(
2697           "context",
2698           [](PyOperation &self) {
2699             self.checkValid();
2700             return self.getContext().getObject();
2701           },
2702           "Context that owns the Operation")
2703       .def_property_readonly("opview", &PyOperation::createOpView);
2704 
2705   auto opViewClass =
2706       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2707           .def(py::init<py::object>(), py::arg("operation"))
2708           .def_property_readonly("operation", &PyOpView::getOperationObject)
2709           .def_property_readonly(
2710               "context",
2711               [](PyOpView &self) {
2712                 return self.getOperation().getContext().getObject();
2713               },
2714               "Context that owns the Operation")
2715           .def("__str__", [](PyOpView &self) {
2716             return py::str(self.getOperationObject());
2717           });
2718   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2719   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2720   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2721   opViewClass.attr("build_generic") = classmethod(
2722       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2723       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2724       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2725       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2726       "Builds a specific, generated OpView based on class level attributes.");
2727 
2728   //----------------------------------------------------------------------------
2729   // Mapping of PyRegion.
2730   //----------------------------------------------------------------------------
2731   py::class_<PyRegion>(m, "Region", py::module_local())
2732       .def_property_readonly(
2733           "blocks",
2734           [](PyRegion &self) {
2735             return PyBlockList(self.getParentOperation(), self.get());
2736           },
2737           "Returns a forward-optimized sequence of blocks.")
2738       .def_property_readonly(
2739           "owner",
2740           [](PyRegion &self) {
2741             return self.getParentOperation()->createOpView();
2742           },
2743           "Returns the operation owning this region.")
2744       .def(
2745           "__iter__",
2746           [](PyRegion &self) {
2747             self.checkValid();
2748             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2749             return PyBlockIterator(self.getParentOperation(), firstBlock);
2750           },
2751           "Iterates over blocks in the region.")
2752       .def("__eq__",
2753            [](PyRegion &self, PyRegion &other) {
2754              return self.get().ptr == other.get().ptr;
2755            })
2756       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2757 
2758   //----------------------------------------------------------------------------
2759   // Mapping of PyBlock.
2760   //----------------------------------------------------------------------------
2761   py::class_<PyBlock>(m, "Block", py::module_local())
2762       .def_property_readonly(
2763           "owner",
2764           [](PyBlock &self) {
2765             return self.getParentOperation()->createOpView();
2766           },
2767           "Returns the owning operation of this block.")
2768       .def_property_readonly(
2769           "region",
2770           [](PyBlock &self) {
2771             MlirRegion region = mlirBlockGetParentRegion(self.get());
2772             return PyRegion(self.getParentOperation(), region);
2773           },
2774           "Returns the owning region of this block.")
2775       .def_property_readonly(
2776           "arguments",
2777           [](PyBlock &self) {
2778             return PyBlockArgumentList(self.getParentOperation(), self.get());
2779           },
2780           "Returns a list of block arguments.")
2781       .def_property_readonly(
2782           "operations",
2783           [](PyBlock &self) {
2784             return PyOperationList(self.getParentOperation(), self.get());
2785           },
2786           "Returns a forward-optimized sequence of operations.")
2787       .def_static(
2788           "create_at_start",
2789           [](PyRegion &parent, py::list pyArgTypes) {
2790             parent.checkValid();
2791             llvm::SmallVector<MlirType, 4> argTypes;
2792             llvm::SmallVector<MlirLocation, 4> argLocs;
2793             argTypes.reserve(pyArgTypes.size());
2794             argLocs.reserve(pyArgTypes.size());
2795             for (auto &pyArg : pyArgTypes) {
2796               argTypes.push_back(pyArg.cast<PyType &>());
2797               // TODO: Pass in a proper location here.
2798               argLocs.push_back(
2799                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2800             }
2801 
2802             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2803                                               argLocs.data());
2804             mlirRegionInsertOwnedBlock(parent, 0, block);
2805             return PyBlock(parent.getParentOperation(), block);
2806           },
2807           py::arg("parent"), py::arg("arg_types") = py::list(),
2808           "Creates and returns a new Block at the beginning of the given "
2809           "region (with given argument types).")
2810       .def(
2811           "append_to",
2812           [](PyBlock &self, PyRegion &region) {
2813             MlirBlock b = self.get();
2814             if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
2815               mlirBlockDetach(b);
2816             mlirRegionAppendOwnedBlock(region.get(), b);
2817           },
2818           "Append this block to a region, transferring ownership if necessary")
2819       .def(
2820           "create_before",
2821           [](PyBlock &self, py::args pyArgTypes) {
2822             self.checkValid();
2823             llvm::SmallVector<MlirType, 4> argTypes;
2824             llvm::SmallVector<MlirLocation, 4> argLocs;
2825             argTypes.reserve(pyArgTypes.size());
2826             argLocs.reserve(pyArgTypes.size());
2827             for (auto &pyArg : pyArgTypes) {
2828               argTypes.push_back(pyArg.cast<PyType &>());
2829               // TODO: Pass in a proper location here.
2830               argLocs.push_back(
2831                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2832             }
2833 
2834             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2835                                               argLocs.data());
2836             MlirRegion region = mlirBlockGetParentRegion(self.get());
2837             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2838             return PyBlock(self.getParentOperation(), block);
2839           },
2840           "Creates and returns a new Block before this block "
2841           "(with given argument types).")
2842       .def(
2843           "create_after",
2844           [](PyBlock &self, py::args pyArgTypes) {
2845             self.checkValid();
2846             llvm::SmallVector<MlirType, 4> argTypes;
2847             llvm::SmallVector<MlirLocation, 4> argLocs;
2848             argTypes.reserve(pyArgTypes.size());
2849             argLocs.reserve(pyArgTypes.size());
2850             for (auto &pyArg : pyArgTypes) {
2851               argTypes.push_back(pyArg.cast<PyType &>());
2852 
2853               // TODO: Pass in a proper location here.
2854               argLocs.push_back(
2855                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2856             }
2857             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2858                                               argLocs.data());
2859             MlirRegion region = mlirBlockGetParentRegion(self.get());
2860             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2861             return PyBlock(self.getParentOperation(), block);
2862           },
2863           "Creates and returns a new Block after this block "
2864           "(with given argument types).")
2865       .def(
2866           "__iter__",
2867           [](PyBlock &self) {
2868             self.checkValid();
2869             MlirOperation firstOperation =
2870                 mlirBlockGetFirstOperation(self.get());
2871             return PyOperationIterator(self.getParentOperation(),
2872                                        firstOperation);
2873           },
2874           "Iterates over operations in the block.")
2875       .def("__eq__",
2876            [](PyBlock &self, PyBlock &other) {
2877              return self.get().ptr == other.get().ptr;
2878            })
2879       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2880       .def(
2881           "__str__",
2882           [](PyBlock &self) {
2883             self.checkValid();
2884             PyPrintAccumulator printAccum;
2885             mlirBlockPrint(self.get(), printAccum.getCallback(),
2886                            printAccum.getUserData());
2887             return printAccum.join();
2888           },
2889           "Returns the assembly form of the block.")
2890       .def(
2891           "append",
2892           [](PyBlock &self, PyOperationBase &operation) {
2893             if (operation.getOperation().isAttached())
2894               operation.getOperation().detachFromParent();
2895 
2896             MlirOperation mlirOperation = operation.getOperation().get();
2897             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2898             operation.getOperation().setAttached(
2899                 self.getParentOperation().getObject());
2900           },
2901           py::arg("operation"),
2902           "Appends an operation to this block. If the operation is currently "
2903           "in another block, it will be moved.");
2904 
2905   //----------------------------------------------------------------------------
2906   // Mapping of PyInsertionPoint.
2907   //----------------------------------------------------------------------------
2908 
2909   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2910       .def(py::init<PyBlock &>(), py::arg("block"),
2911            "Inserts after the last operation but still inside the block.")
2912       .def("__enter__", &PyInsertionPoint::contextEnter)
2913       .def("__exit__", &PyInsertionPoint::contextExit)
2914       .def_property_readonly_static(
2915           "current",
2916           [](py::object & /*class*/) {
2917             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2918             if (!ip)
2919               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2920             return ip;
2921           },
2922           "Gets the InsertionPoint bound to the current thread or raises "
2923           "ValueError if none has been set")
2924       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2925            "Inserts before a referenced operation.")
2926       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2927                   py::arg("block"), "Inserts at the beginning of the block.")
2928       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2929                   py::arg("block"), "Inserts before the block terminator.")
2930       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2931            "Inserts an operation.")
2932       .def_property_readonly(
2933           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2934           "Returns the block that this InsertionPoint points to.");
2935 
2936   //----------------------------------------------------------------------------
2937   // Mapping of PyAttribute.
2938   //----------------------------------------------------------------------------
2939   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2940       // Delegate to the PyAttribute copy constructor, which will also lifetime
2941       // extend the backing context which owns the MlirAttribute.
2942       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2943            "Casts the passed attribute to the generic Attribute")
2944       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2945                              &PyAttribute::getCapsule)
2946       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2947       .def_static(
2948           "parse",
2949           [](std::string attrSpec, DefaultingPyMlirContext context) {
2950             MlirAttribute type = mlirAttributeParseGet(
2951                 context->get(), toMlirStringRef(attrSpec));
2952             // TODO: Rework error reporting once diagnostic engine is exposed
2953             // in C API.
2954             if (mlirAttributeIsNull(type)) {
2955               throw SetPyError(PyExc_ValueError,
2956                                Twine("Unable to parse attribute: '") +
2957                                    attrSpec + "'");
2958             }
2959             return PyAttribute(context->getRef(), type);
2960           },
2961           py::arg("asm"), py::arg("context") = py::none(),
2962           "Parses an attribute from an assembly form")
2963       .def_property_readonly(
2964           "context",
2965           [](PyAttribute &self) { return self.getContext().getObject(); },
2966           "Context that owns the Attribute")
2967       .def_property_readonly("type",
2968                              [](PyAttribute &self) {
2969                                return PyType(self.getContext()->getRef(),
2970                                              mlirAttributeGetType(self));
2971                              })
2972       .def(
2973           "get_named",
2974           [](PyAttribute &self, std::string name) {
2975             return PyNamedAttribute(self, std::move(name));
2976           },
2977           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2978       .def("__eq__",
2979            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2980       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2981       .def("__hash__",
2982            [](PyAttribute &self) {
2983              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2984            })
2985       .def(
2986           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2987           kDumpDocstring)
2988       .def(
2989           "__str__",
2990           [](PyAttribute &self) {
2991             PyPrintAccumulator printAccum;
2992             mlirAttributePrint(self, printAccum.getCallback(),
2993                                printAccum.getUserData());
2994             return printAccum.join();
2995           },
2996           "Returns the assembly form of the Attribute.")
2997       .def("__repr__", [](PyAttribute &self) {
2998         // Generally, assembly formats are not printed for __repr__ because
2999         // this can cause exceptionally long debug output and exceptions.
3000         // However, attribute values are generally considered useful and are
3001         // printed. This may need to be re-evaluated if debug dumps end up
3002         // being excessive.
3003         PyPrintAccumulator printAccum;
3004         printAccum.parts.append("Attribute(");
3005         mlirAttributePrint(self, printAccum.getCallback(),
3006                            printAccum.getUserData());
3007         printAccum.parts.append(")");
3008         return printAccum.join();
3009       });
3010 
3011   //----------------------------------------------------------------------------
3012   // Mapping of PyNamedAttribute
3013   //----------------------------------------------------------------------------
3014   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
3015       .def("__repr__",
3016            [](PyNamedAttribute &self) {
3017              PyPrintAccumulator printAccum;
3018              printAccum.parts.append("NamedAttribute(");
3019              printAccum.parts.append(
3020                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
3021                          mlirIdentifierStr(self.namedAttr.name).length));
3022              printAccum.parts.append("=");
3023              mlirAttributePrint(self.namedAttr.attribute,
3024                                 printAccum.getCallback(),
3025                                 printAccum.getUserData());
3026              printAccum.parts.append(")");
3027              return printAccum.join();
3028            })
3029       .def_property_readonly(
3030           "name",
3031           [](PyNamedAttribute &self) {
3032             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
3033                            mlirIdentifierStr(self.namedAttr.name).length);
3034           },
3035           "The name of the NamedAttribute binding")
3036       .def_property_readonly(
3037           "attr",
3038           [](PyNamedAttribute &self) {
3039             // TODO: When named attribute is removed/refactored, also remove
3040             // this constructor (it does an inefficient table lookup).
3041             auto contextRef = PyMlirContext::forContext(
3042                 mlirAttributeGetContext(self.namedAttr.attribute));
3043             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
3044           },
3045           py::keep_alive<0, 1>(),
3046           "The underlying generic attribute of the NamedAttribute binding");
3047 
3048   //----------------------------------------------------------------------------
3049   // Mapping of PyType.
3050   //----------------------------------------------------------------------------
3051   py::class_<PyType>(m, "Type", py::module_local())
3052       // Delegate to the PyType copy constructor, which will also lifetime
3053       // extend the backing context which owns the MlirType.
3054       .def(py::init<PyType &>(), py::arg("cast_from_type"),
3055            "Casts the passed type to the generic Type")
3056       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3057       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3058       .def_static(
3059           "parse",
3060           [](std::string typeSpec, DefaultingPyMlirContext context) {
3061             MlirType type =
3062                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3063             // TODO: Rework error reporting once diagnostic engine is exposed
3064             // in C API.
3065             if (mlirTypeIsNull(type)) {
3066               throw SetPyError(PyExc_ValueError,
3067                                Twine("Unable to parse type: '") + typeSpec +
3068                                    "'");
3069             }
3070             return PyType(context->getRef(), type);
3071           },
3072           py::arg("asm"), py::arg("context") = py::none(),
3073           kContextParseTypeDocstring)
3074       .def_property_readonly(
3075           "context", [](PyType &self) { return self.getContext().getObject(); },
3076           "Context that owns the Type")
3077       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3078       .def("__eq__", [](PyType &self, py::object &other) { return false; })
3079       .def("__hash__",
3080            [](PyType &self) {
3081              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3082            })
3083       .def(
3084           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3085       .def(
3086           "__str__",
3087           [](PyType &self) {
3088             PyPrintAccumulator printAccum;
3089             mlirTypePrint(self, printAccum.getCallback(),
3090                           printAccum.getUserData());
3091             return printAccum.join();
3092           },
3093           "Returns the assembly form of the type.")
3094       .def("__repr__", [](PyType &self) {
3095         // Generally, assembly formats are not printed for __repr__ because
3096         // this can cause exceptionally long debug output and exceptions.
3097         // However, types are an exception as they typically have compact
3098         // assembly forms and printing them is useful.
3099         PyPrintAccumulator printAccum;
3100         printAccum.parts.append("Type(");
3101         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
3102         printAccum.parts.append(")");
3103         return printAccum.join();
3104       });
3105 
3106   //----------------------------------------------------------------------------
3107   // Mapping of Value.
3108   //----------------------------------------------------------------------------
3109   py::class_<PyValue>(m, "Value", py::module_local())
3110       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3111       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3112       .def_property_readonly(
3113           "context",
3114           [](PyValue &self) { return self.getParentOperation()->getContext(); },
3115           "Context in which the value lives.")
3116       .def(
3117           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3118           kDumpDocstring)
3119       .def_property_readonly(
3120           "owner",
3121           [](PyValue &self) {
3122             assert(mlirOperationEqual(self.getParentOperation()->get(),
3123                                       mlirOpResultGetOwner(self.get())) &&
3124                    "expected the owner of the value in Python to match that in "
3125                    "the IR");
3126             return self.getParentOperation().getObject();
3127           })
3128       .def("__eq__",
3129            [](PyValue &self, PyValue &other) {
3130              return self.get().ptr == other.get().ptr;
3131            })
3132       .def("__eq__", [](PyValue &self, py::object other) { return false; })
3133       .def("__hash__",
3134            [](PyValue &self) {
3135              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3136            })
3137       .def(
3138           "__str__",
3139           [](PyValue &self) {
3140             PyPrintAccumulator printAccum;
3141             printAccum.parts.append("Value(");
3142             mlirValuePrint(self.get(), printAccum.getCallback(),
3143                            printAccum.getUserData());
3144             printAccum.parts.append(")");
3145             return printAccum.join();
3146           },
3147           kValueDunderStrDocstring)
3148       .def_property_readonly("type", [](PyValue &self) {
3149         return PyType(self.getParentOperation()->getContext(),
3150                       mlirValueGetType(self.get()));
3151       });
3152   PyBlockArgument::bind(m);
3153   PyOpResult::bind(m);
3154 
3155   //----------------------------------------------------------------------------
3156   // Mapping of SymbolTable.
3157   //----------------------------------------------------------------------------
3158   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3159       .def(py::init<PyOperationBase &>())
3160       .def("__getitem__", &PySymbolTable::dunderGetItem)
3161       .def("insert", &PySymbolTable::insert, py::arg("operation"))
3162       .def("erase", &PySymbolTable::erase, py::arg("operation"))
3163       .def("__delitem__", &PySymbolTable::dunderDel)
3164       .def("__contains__",
3165            [](PySymbolTable &table, const std::string &name) {
3166              return !mlirOperationIsNull(mlirSymbolTableLookup(
3167                  table, mlirStringRefCreate(name.data(), name.length())));
3168            })
3169       // Static helpers.
3170       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3171                   py::arg("symbol"), py::arg("name"))
3172       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3173                   py::arg("symbol"))
3174       .def_static("get_visibility", &PySymbolTable::getVisibility,
3175                   py::arg("symbol"))
3176       .def_static("set_visibility", &PySymbolTable::setVisibility,
3177                   py::arg("symbol"), py::arg("visibility"))
3178       .def_static("replace_all_symbol_uses",
3179                   &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3180                   py::arg("new_symbol"), py::arg("from_op"))
3181       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3182                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
3183                   py::arg("callback"));
3184 
  // Container bindings.
  // Each call below binds the named helper class (lists and iterators over
  // block arguments, blocks, operations, op attributes, operands, results and
  // regions) into module `m`.
  // NOTE(review): these are registered after the core IR classes above;
  // pybind11 renders signatures from the types registered at def time, so
  // keep this ordering unless verified to be irrelevant.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  // Binds the PyGlobalDebugFlag helper class into module `m`.
  PyGlobalDebugFlag::bind(m);
3199 }
3200