1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include <pybind11/stl.h>
23 
24 namespace py = pybind11;
25 using namespace mlir;
26 using namespace mlir::python;
27 
28 using llvm::SmallVector;
29 using llvm::StringRef;
30 using llvm::Twine;
31 
32 //------------------------------------------------------------------------------
33 // Docstrings (trivial, non-duplicated docstrings are included inline).
34 //------------------------------------------------------------------------------
35 
// Docstrings attached to the Python bindings (for example,
// kAppendBlockDocstring is passed to BlockList.append below). Argument names
// in these texts must match the Python keyword names exactly.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

// The keyword for the local-scope option is snake_case like every other
// argument listed here (use_local_scope), not "use_local_Scope".
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
136 
137 //------------------------------------------------------------------------------
138 // Utilities.
139 //------------------------------------------------------------------------------
140 
141 /// Helper for creating an @classmethod.
142 template <class Func, typename... Args>
143 py::object classmethod(Func f, Args... args) {
144   py::object cf = py::cpp_function(f, args...);
145   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
146 }
147 
148 static py::object
149 createCustomDialectWrapper(const std::string &dialectNamespace,
150                            py::object dialectDescriptor) {
151   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
152   if (!dialectClass) {
153     // Use the base class.
154     return py::cast(PyDialect(std::move(dialectDescriptor)));
155   }
156 
157   // Create the custom implementation.
158   return (*dialectClass)(std::move(dialectDescriptor));
159 }
160 
161 static MlirStringRef toMlirStringRef(const std::string &s) {
162   return mlirStringRefCreate(s.data(), s.size());
163 }
164 
165 /// Wrapper for the global LLVM debugging flag.
166 struct PyGlobalDebugFlag {
167   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
168 
169   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
170 
171   static void bind(py::module &m) {
172     // Debug flags.
173     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
174         .def_property_static("flag", &PyGlobalDebugFlag::get,
175                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
176   }
177 };
178 
179 //------------------------------------------------------------------------------
180 // Collections.
181 //------------------------------------------------------------------------------
182 
183 namespace {
184 
185 class PyRegionIterator {
186 public:
187   PyRegionIterator(PyOperationRef operation)
188       : operation(std::move(operation)) {}
189 
190   PyRegionIterator &dunderIter() { return *this; }
191 
192   PyRegion dunderNext() {
193     operation->checkValid();
194     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
195       throw py::stop_iteration();
196     }
197     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
198     return PyRegion(operation, region);
199   }
200 
201   static void bind(py::module &m) {
202     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
203         .def("__iter__", &PyRegionIterator::dunderIter)
204         .def("__next__", &PyRegionIterator::dunderNext);
205   }
206 
207 private:
208   PyOperationRef operation;
209   int nextIndex = 0;
210 };
211 
212 /// Regions of an op are fixed length and indexed numerically so are represented
213 /// with a sequence-like container.
214 class PyRegionList {
215 public:
216   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
217 
218   intptr_t dunderLen() {
219     operation->checkValid();
220     return mlirOperationGetNumRegions(operation->get());
221   }
222 
223   PyRegion dunderGetItem(intptr_t index) {
224     // dunderLen checks validity.
225     if (index < 0 || index >= dunderLen()) {
226       throw SetPyError(PyExc_IndexError,
227                        "attempt to access out of bounds region");
228     }
229     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
230     return PyRegion(operation, region);
231   }
232 
233   static void bind(py::module &m) {
234     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
235         .def("__len__", &PyRegionList::dunderLen)
236         .def("__getitem__", &PyRegionList::dunderGetItem);
237   }
238 
239 private:
240   PyOperationRef operation;
241 };
242 
243 class PyBlockIterator {
244 public:
245   PyBlockIterator(PyOperationRef operation, MlirBlock next)
246       : operation(std::move(operation)), next(next) {}
247 
248   PyBlockIterator &dunderIter() { return *this; }
249 
250   PyBlock dunderNext() {
251     operation->checkValid();
252     if (mlirBlockIsNull(next)) {
253       throw py::stop_iteration();
254     }
255 
256     PyBlock returnBlock(operation, next);
257     next = mlirBlockGetNextInRegion(next);
258     return returnBlock;
259   }
260 
261   static void bind(py::module &m) {
262     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
263         .def("__iter__", &PyBlockIterator::dunderIter)
264         .def("__next__", &PyBlockIterator::dunderNext);
265   }
266 
267 private:
268   PyOperationRef operation;
269   MlirBlock next;
270 };
271 
272 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
273 /// we present them as a more full-featured list-like container but optimize
274 /// it for forward iteration. Blocks are always owned by a region.
275 class PyBlockList {
276 public:
277   PyBlockList(PyOperationRef operation, MlirRegion region)
278       : operation(std::move(operation)), region(region) {}
279 
280   PyBlockIterator dunderIter() {
281     operation->checkValid();
282     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
283   }
284 
285   intptr_t dunderLen() {
286     operation->checkValid();
287     intptr_t count = 0;
288     MlirBlock block = mlirRegionGetFirstBlock(region);
289     while (!mlirBlockIsNull(block)) {
290       count += 1;
291       block = mlirBlockGetNextInRegion(block);
292     }
293     return count;
294   }
295 
296   PyBlock dunderGetItem(intptr_t index) {
297     operation->checkValid();
298     if (index < 0) {
299       throw SetPyError(PyExc_IndexError,
300                        "attempt to access out of bounds block");
301     }
302     MlirBlock block = mlirRegionGetFirstBlock(region);
303     while (!mlirBlockIsNull(block)) {
304       if (index == 0) {
305         return PyBlock(operation, block);
306       }
307       block = mlirBlockGetNextInRegion(block);
308       index -= 1;
309     }
310     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
311   }
312 
313   PyBlock appendBlock(py::args pyArgTypes) {
314     operation->checkValid();
315     llvm::SmallVector<MlirType, 4> argTypes;
316     argTypes.reserve(pyArgTypes.size());
317     for (auto &pyArg : pyArgTypes) {
318       argTypes.push_back(pyArg.cast<PyType &>());
319     }
320 
321     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
322     mlirRegionAppendOwnedBlock(region, block);
323     return PyBlock(operation, block);
324   }
325 
326   static void bind(py::module &m) {
327     py::class_<PyBlockList>(m, "BlockList", py::module_local())
328         .def("__getitem__", &PyBlockList::dunderGetItem)
329         .def("__iter__", &PyBlockList::dunderIter)
330         .def("__len__", &PyBlockList::dunderLen)
331         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
332   }
333 
334 private:
335   PyOperationRef operation;
336   MlirRegion region;
337 };
338 
339 class PyOperationIterator {
340 public:
341   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
342       : parentOperation(std::move(parentOperation)), next(next) {}
343 
344   PyOperationIterator &dunderIter() { return *this; }
345 
346   py::object dunderNext() {
347     parentOperation->checkValid();
348     if (mlirOperationIsNull(next)) {
349       throw py::stop_iteration();
350     }
351 
352     PyOperationRef returnOperation =
353         PyOperation::forOperation(parentOperation->getContext(), next);
354     next = mlirOperationGetNextInBlock(next);
355     return returnOperation->createOpView();
356   }
357 
358   static void bind(py::module &m) {
359     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
360         .def("__iter__", &PyOperationIterator::dunderIter)
361         .def("__next__", &PyOperationIterator::dunderNext);
362   }
363 
364 private:
365   PyOperationRef parentOperation;
366   MlirOperation next;
367 };
368 
369 /// Operations are exposed by the C-API as a forward-only linked list. In
370 /// Python, we present them as a more full-featured list-like container but
371 /// optimize it for forward iteration. Iterable operations are always owned
372 /// by a block.
373 class PyOperationList {
374 public:
375   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
376       : parentOperation(std::move(parentOperation)), block(block) {}
377 
378   PyOperationIterator dunderIter() {
379     parentOperation->checkValid();
380     return PyOperationIterator(parentOperation,
381                                mlirBlockGetFirstOperation(block));
382   }
383 
384   intptr_t dunderLen() {
385     parentOperation->checkValid();
386     intptr_t count = 0;
387     MlirOperation childOp = mlirBlockGetFirstOperation(block);
388     while (!mlirOperationIsNull(childOp)) {
389       count += 1;
390       childOp = mlirOperationGetNextInBlock(childOp);
391     }
392     return count;
393   }
394 
395   py::object dunderGetItem(intptr_t index) {
396     parentOperation->checkValid();
397     if (index < 0) {
398       throw SetPyError(PyExc_IndexError,
399                        "attempt to access out of bounds operation");
400     }
401     MlirOperation childOp = mlirBlockGetFirstOperation(block);
402     while (!mlirOperationIsNull(childOp)) {
403       if (index == 0) {
404         return PyOperation::forOperation(parentOperation->getContext(), childOp)
405             ->createOpView();
406       }
407       childOp = mlirOperationGetNextInBlock(childOp);
408       index -= 1;
409     }
410     throw SetPyError(PyExc_IndexError,
411                      "attempt to access out of bounds operation");
412   }
413 
414   static void bind(py::module &m) {
415     py::class_<PyOperationList>(m, "OperationList", py::module_local())
416         .def("__getitem__", &PyOperationList::dunderGetItem)
417         .def("__iter__", &PyOperationList::dunderIter)
418         .def("__len__", &PyOperationList::dunderLen);
419   }
420 
421 private:
422   PyOperationRef parentOperation;
423   MlirBlock block;
424 };
425 
426 } // namespace
427 
428 //------------------------------------------------------------------------------
429 // PyMlirContext
430 //------------------------------------------------------------------------------
431 
432 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
433   py::gil_scoped_acquire acquire;
434   auto &liveContexts = getLiveContexts();
435   liveContexts[context.ptr] = this;
436 }
437 
438 PyMlirContext::~PyMlirContext() {
439   // Note that the only public way to construct an instance is via the
440   // forContext method, which always puts the associated handle into
441   // liveContexts.
442   py::gil_scoped_acquire acquire;
443   getLiveContexts().erase(context.ptr);
444   mlirContextDestroy(context);
445 }
446 
447 py::object PyMlirContext::getCapsule() {
448   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
449 }
450 
451 py::object PyMlirContext::createFromCapsule(py::object capsule) {
452   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
453   if (mlirContextIsNull(rawContext))
454     throw py::error_already_set();
455   return forContext(rawContext).releaseObject();
456 }
457 
458 PyMlirContext *PyMlirContext::createNewContextForInit() {
459   MlirContext context = mlirContextCreate();
460   mlirRegisterAllDialects(context);
461   return new PyMlirContext(context);
462 }
463 
/// Returns a reference to the Python wrapper for `context`. A new wrapper is
/// created if none is currently registered in the live-context map.
///
/// Wrappers created here are "unowned". The raw pointer is cast with
/// pybind11's default policy (automatic_reference, which for a pointer does
/// not take ownership), so Python does not delete the wrapper. That in turn
/// means ~PyMlirContext, and its mlirContextDestroy call, is not run from
/// Python for contexts obtained this way.
/// NOTE(review): confirm this intentional leak is the desired lifetime for
/// externally-owned contexts.
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
  // The live-context map is read and mutated only while holding the GIL.
  py::gil_scoped_acquire acquire;
  auto &liveContexts = getLiveContexts();
  auto it = liveContexts.find(context.ptr);
  if (it == liveContexts.end()) {
    // Create.
    // The constructor already registers the wrapper in liveContexts, so the
    // explicit assignment below is redundant but harmless.
    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
    py::object pyRef = py::cast(unownedContextWrapper);
    assert(pyRef && "cast to py::object failed");
    liveContexts[context.ptr] = unownedContextWrapper;
    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
  }
  // Use existing.
  // Casting the registered pointer yields the Python object that wraps it.
  py::object pyRef = py::cast(it->second);
  return PyMlirContextRef(it->second, std::move(pyRef));
}
480 
481 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
482   static LiveContextMap liveContexts;
483   return liveContexts;
484 }
485 
486 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
487 
488 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
489 
490 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
491 
492 pybind11::object PyMlirContext::contextEnter() {
493   return PyThreadContextEntry::pushContext(*this);
494 }
495 
496 void PyMlirContext::contextExit(pybind11::object excType,
497                                 pybind11::object excVal,
498                                 pybind11::object excTb) {
499   PyThreadContextEntry::popContext(*this);
500 }
501 
502 PyMlirContext &DefaultingPyMlirContext::resolve() {
503   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
504   if (!context) {
505     throw SetPyError(
506         PyExc_RuntimeError,
507         "An MLIR function requires a Context but none was provided in the call "
508         "or from the surrounding environment. Either pass to the function with "
509         "a 'context=' argument or establish a default using 'with Context():'");
510   }
511   return *context;
512 }
513 
514 //------------------------------------------------------------------------------
515 // PyThreadContextEntry management
516 //------------------------------------------------------------------------------
517 
518 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
519   static thread_local std::vector<PyThreadContextEntry> stack;
520   return stack;
521 }
522 
523 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
524   auto &stack = getStack();
525   if (stack.empty())
526     return nullptr;
527   return &stack.back();
528 }
529 
530 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
531                                 py::object insertionPoint,
532                                 py::object location) {
533   auto &stack = getStack();
534   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
535                      std::move(location));
536   // If the new stack has more than one entry and the context of the new top
537   // entry matches the previous, copy the insertionPoint and location from the
538   // previous entry if missing from the new top entry.
539   if (stack.size() > 1) {
540     auto &prev = *(stack.rbegin() + 1);
541     auto &current = stack.back();
542     if (current.context.is(prev.context)) {
543       // Default non-context objects from the previous entry.
544       if (!current.insertionPoint)
545         current.insertionPoint = prev.insertionPoint;
546       if (!current.location)
547         current.location = prev.location;
548     }
549   }
550 }
551 
552 PyMlirContext *PyThreadContextEntry::getContext() {
553   if (!context)
554     return nullptr;
555   return py::cast<PyMlirContext *>(context);
556 }
557 
558 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
559   if (!insertionPoint)
560     return nullptr;
561   return py::cast<PyInsertionPoint *>(insertionPoint);
562 }
563 
564 PyLocation *PyThreadContextEntry::getLocation() {
565   if (!location)
566     return nullptr;
567   return py::cast<PyLocation *>(location);
568 }
569 
570 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
571   auto *tos = getTopOfStack();
572   return tos ? tos->getContext() : nullptr;
573 }
574 
575 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
576   auto *tos = getTopOfStack();
577   return tos ? tos->getInsertionPoint() : nullptr;
578 }
579 
580 PyLocation *PyThreadContextEntry::getDefaultLocation() {
581   auto *tos = getTopOfStack();
582   return tos ? tos->getLocation() : nullptr;
583 }
584 
585 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
586   py::object contextObj = py::cast(context);
587   push(FrameKind::Context, /*context=*/contextObj,
588        /*insertionPoint=*/py::object(),
589        /*location=*/py::object());
590   return contextObj;
591 }
592 
593 void PyThreadContextEntry::popContext(PyMlirContext &context) {
594   auto &stack = getStack();
595   if (stack.empty())
596     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
597   auto &tos = stack.back();
598   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
599     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
600   stack.pop_back();
601 }
602 
603 py::object
604 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
605   py::object contextObj =
606       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
607   py::object insertionPointObj = py::cast(insertionPoint);
608   push(FrameKind::InsertionPoint,
609        /*context=*/contextObj,
610        /*insertionPoint=*/insertionPointObj,
611        /*location=*/py::object());
612   return insertionPointObj;
613 }
614 
615 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
616   auto &stack = getStack();
617   if (stack.empty())
618     throw SetPyError(PyExc_RuntimeError,
619                      "Unbalanced InsertionPoint enter/exit");
620   auto &tos = stack.back();
621   if (tos.frameKind != FrameKind::InsertionPoint &&
622       tos.getInsertionPoint() != &insertionPoint)
623     throw SetPyError(PyExc_RuntimeError,
624                      "Unbalanced InsertionPoint enter/exit");
625   stack.pop_back();
626 }
627 
628 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
629   py::object contextObj = location.getContext().getObject();
630   py::object locationObj = py::cast(location);
631   push(FrameKind::Location, /*context=*/contextObj,
632        /*insertionPoint=*/py::object(),
633        /*location=*/locationObj);
634   return locationObj;
635 }
636 
637 void PyThreadContextEntry::popLocation(PyLocation &location) {
638   auto &stack = getStack();
639   if (stack.empty())
640     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
641   auto &tos = stack.back();
642   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
643     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
644   stack.pop_back();
645 }
646 
647 //------------------------------------------------------------------------------
648 // PyDialect, PyDialectDescriptor, PyDialects
649 //------------------------------------------------------------------------------
650 
651 MlirDialect PyDialects::getDialectForKey(const std::string &key,
652                                          bool attrError) {
653   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
654                                                     {key.data(), key.size()});
655   if (mlirDialectIsNull(dialect)) {
656     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
657                      Twine("Dialect '") + key + "' not found");
658   }
659   return dialect;
660 }
661 
662 //------------------------------------------------------------------------------
663 // PyLocation
664 //------------------------------------------------------------------------------
665 
666 py::object PyLocation::getCapsule() {
667   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
668 }
669 
670 PyLocation PyLocation::createFromCapsule(py::object capsule) {
671   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
672   if (mlirLocationIsNull(rawLoc))
673     throw py::error_already_set();
674   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
675                     rawLoc);
676 }
677 
678 py::object PyLocation::contextEnter() {
679   return PyThreadContextEntry::pushLocation(*this);
680 }
681 
682 void PyLocation::contextExit(py::object excType, py::object excVal,
683                              py::object excTb) {
684   PyThreadContextEntry::popLocation(*this);
685 }
686 
687 PyLocation &DefaultingPyLocation::resolve() {
688   auto *location = PyThreadContextEntry::getDefaultLocation();
689   if (!location) {
690     throw SetPyError(
691         PyExc_RuntimeError,
692         "An MLIR function requires a Location but none was provided in the "
693         "call or from the surrounding environment. Either pass to the function "
694         "with a 'loc=' argument or establish a default using 'with loc:'");
695   }
696   return *location;
697 }
698 
699 //------------------------------------------------------------------------------
700 // PyModule
701 //------------------------------------------------------------------------------
702 
703 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
704     : BaseContextObject(std::move(contextRef)), module(module) {}
705 
706 PyModule::~PyModule() {
707   py::gil_scoped_acquire acquire;
708   auto &liveModules = getContext()->liveModules;
709   assert(liveModules.count(module.ptr) == 1 &&
710          "destroying module not in live map");
711   liveModules.erase(module.ptr);
712   mlirModuleDestroy(module);
713 }
714 
/// Returns the Python wrapper for `module`. If the owning context's
/// live-module map has no entry for it yet, a new wrapper is created and
/// registered.
///
/// A newly created wrapper is cast with take_ownership, so Python owns it.
/// When Python deletes it, ~PyModule unregisters the module and calls
/// mlirModuleDestroy.
PyModuleRef PyModule::forModule(MlirModule module) {
  MlirContext context = mlirModuleGetContext(module);
  // Obtain (or create) the context wrapper first; it owns the live map.
  PyMlirContextRef contextRef = PyMlirContext::forContext(context);

  py::gil_scoped_acquire acquire;
  auto &liveModules = contextRef->liveModules;
  auto it = liveModules.find(module.ptr);
  if (it == liveModules.end()) {
    // Create.
    PyModule *unownedModule = new PyModule(std::move(contextRef), module);
    // Note that the default return value policy on cast is automatic_reference,
    // which does not take ownership (delete will not be called).
    // Just be explicit.
    py::object pyRef =
        py::cast(unownedModule, py::return_value_policy::take_ownership);
    unownedModule->handle = pyRef;
    // The live map records the Python handle next to the raw pointer so that
    // later lookups can hand out new references to the same Python object.
    liveModules[module.ptr] =
        std::make_pair(unownedModule->handle, unownedModule);
    return PyModuleRef(unownedModule, std::move(pyRef));
  }
  // Use existing.
  // Borrow (increment) a reference to the already-registered Python object.
  PyModule *existing = it->second.second;
  py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
  return PyModuleRef(existing, std::move(pyRef));
}
740 
741 py::object PyModule::createFromCapsule(py::object capsule) {
742   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
743   if (mlirModuleIsNull(rawModule))
744     throw py::error_already_set();
745   return forModule(rawModule).releaseObject();
746 }
747 
748 py::object PyModule::getCapsule() {
749   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
750 }
751 
752 //------------------------------------------------------------------------------
753 // PyOperation
754 //------------------------------------------------------------------------------
755 
756 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
757     : BaseContextObject(std::move(contextRef)), operation(operation) {}
758 
759 PyOperation::~PyOperation() {
760   // If the operation has already been invalidated there is nothing to do.
761   if (!valid)
762     return;
763   auto &liveOperations = getContext()->liveOperations;
764   assert(liveOperations.count(operation.ptr) == 1 &&
765          "destroying operation not in live map");
766   liveOperations.erase(operation.ptr);
767   if (!isAttached()) {
768     mlirOperationDestroy(operation);
769   }
770 }
771 
/// Creates a new Python-owned wrapper for `operation` and registers it in the
/// context's live-operation map. The caller must ensure no wrapper exists yet
/// (forOperation and createDetached both guard this).
///
/// `parentKeepAlive`, when non-null, is stored on the wrapper, presumably to
/// keep the parent's Python object alive as long as this operation.
PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
                                           MlirOperation operation,
                                           py::object parentKeepAlive) {
  // Grab the map before contextRef is moved into the new wrapper.
  auto &liveOperations = contextRef->liveOperations;
  // Create.
  PyOperation *unownedOperation =
      new PyOperation(std::move(contextRef), operation);
  // Note that the default return value policy on cast is automatic_reference,
  // which does not take ownership (delete will not be called).
  // Just be explicit.
  py::object pyRef =
      py::cast(unownedOperation, py::return_value_policy::take_ownership);
  unownedOperation->handle = pyRef;
  if (parentKeepAlive) {
    unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
  }
  // Record the Python handle with the raw pointer for later lookups.
  liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
  return PyOperationRef(unownedOperation, std::move(pyRef));
}
791 
792 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
793                                          MlirOperation operation,
794                                          py::object parentKeepAlive) {
795   auto &liveOperations = contextRef->liveOperations;
796   auto it = liveOperations.find(operation.ptr);
797   if (it == liveOperations.end()) {
798     // Create.
799     return createInstance(std::move(contextRef), operation,
800                           std::move(parentKeepAlive));
801   }
802   // Use existing.
803   PyOperation *existing = it->second.second;
804   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
805   return PyOperationRef(existing, std::move(pyRef));
806 }
807 
808 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
809                                            MlirOperation operation,
810                                            py::object parentKeepAlive) {
811   auto &liveOperations = contextRef->liveOperations;
812   assert(liveOperations.count(operation.ptr) == 0 &&
813          "cannot create detached operation that already exists");
814   (void)liveOperations;
815 
816   PyOperationRef created = createInstance(std::move(contextRef), operation,
817                                           std::move(parentKeepAlive));
818   created->attached = false;
819   return created;
820 }
821 
822 void PyOperation::checkValid() const {
823   if (!valid) {
824     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
825   }
826 }
827 
828 void PyOperationBase::print(py::object fileObject, bool binary,
829                             llvm::Optional<int64_t> largeElementsLimit,
830                             bool enableDebugInfo, bool prettyDebugInfo,
831                             bool printGenericOpForm, bool useLocalScope) {
832   PyOperation &operation = getOperation();
833   operation.checkValid();
834   if (fileObject.is_none())
835     fileObject = py::module::import("sys").attr("stdout");
836 
837   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
838     fileObject.attr("write")("// Verification failed, printing generic form\n");
839     printGenericOpForm = true;
840   }
841 
842   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
843   if (largeElementsLimit)
844     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
845   if (enableDebugInfo)
846     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
847   if (printGenericOpForm)
848     mlirOpPrintingFlagsPrintGenericOpForm(flags);
849 
850   PyFileAccumulator accum(fileObject, binary);
851   py::gil_scoped_release();
852   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
853                               accum.getUserData());
854   mlirOpPrintingFlagsDestroy(flags);
855 }
856 
857 py::object PyOperationBase::getAsm(bool binary,
858                                    llvm::Optional<int64_t> largeElementsLimit,
859                                    bool enableDebugInfo, bool prettyDebugInfo,
860                                    bool printGenericOpForm,
861                                    bool useLocalScope) {
862   py::object fileObject;
863   if (binary) {
864     fileObject = py::module::import("io").attr("BytesIO")();
865   } else {
866     fileObject = py::module::import("io").attr("StringIO")();
867   }
868   print(fileObject, /*binary=*/binary,
869         /*largeElementsLimit=*/largeElementsLimit,
870         /*enableDebugInfo=*/enableDebugInfo,
871         /*prettyDebugInfo=*/prettyDebugInfo,
872         /*printGenericOpForm=*/printGenericOpForm,
873         /*useLocalScope=*/useLocalScope);
874 
875   return fileObject.attr("getvalue")();
876 }
877 
878 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
879   checkValid();
880   if (!isAttached())
881     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
882   MlirOperation operation = mlirOperationGetParentOperation(get());
883   if (mlirOperationIsNull(operation))
884     return {};
885   return PyOperation::forOperation(getContext(), operation);
886 }
887 
888 PyBlock PyOperation::getBlock() {
889   checkValid();
890   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
891   MlirBlock block = mlirOperationGetBlock(get());
892   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
893   assert(parentOperation && "Operation has no parent");
894   return PyBlock{std::move(*parentOperation), block};
895 }
896 
897 py::object PyOperation::getCapsule() {
898   checkValid();
899   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
900 }
901 
902 py::object PyOperation::createFromCapsule(py::object capsule) {
903   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
904   if (mlirOperationIsNull(rawOperation))
905     throw py::error_already_set();
906   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
907   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
908       .releaseObject();
909 }
910 
/// Generic operation builder: creates an operation named `name` with the
/// given result types, operands, attributes (a dict of str -> Attribute),
/// successor blocks and number of (empty) regions at `location`, and returns
/// it through createOpView().
///
/// `maybeIp` selects insertion: the Python value `False` means "do not
/// insert"; None means use the default insertion point from the thread
/// context (if any); anything else is cast to an InsertionPoint.
///
/// Raises ValueError for a negative region count or None operands, result
/// types or successors; raises py::cast_error for non-string attribute keys
/// or invalid attribute values.
py::object PyOperation::create(
    std::string name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, py::object maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Attribute names are owned here (std::string) because the named
  // attributes built below only reference their bytes.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Create `regions` fresh regions; ownership passes to the state.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation. From here on it is owned by the detached
  // Python wrapper, so later exceptions (e.g. from the cast or insert below)
  // leave a valid, detached operation rather than a leaked state.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);

  // InsertPoint active?
  // `False` disables insertion; None means the thread's default insertion
  // point (which may be null, in which case nothing is inserted).
  if (!maybeIp.is(py::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
      ip = PyThreadContextEntry::getDefaultInsertionPoint();
    } else {
      ip = py::cast<PyInsertionPoint *>(maybeIp);
    }
    if (ip)
      ip->insert(*created.get());
  }

  return created->createOpView();
}
1043 
1044 py::object PyOperation::createOpView() {
1045   checkValid();
1046   MlirIdentifier ident = mlirOperationGetName(get());
1047   MlirStringRef identStr = mlirIdentifierStr(ident);
1048   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1049       StringRef(identStr.data, identStr.length));
1050   if (opViewClass)
1051     return (*opViewClass)(getRef().getObject());
1052   return py::cast(PyOpView(getRef().getObject()));
1053 }
1054 
1055 void PyOperation::erase() {
1056   checkValid();
1057   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1058   // Python reference to a child operation is live. All children should also
1059   // have their `valid` bit set to false.
1060   auto &liveOperations = getContext()->liveOperations;
1061   if (liveOperations.count(operation.ptr))
1062     liveOperations.erase(operation.ptr);
1063   mlirOperationDestroy(operation);
1064   valid = false;
1065 }
1066 
1067 //------------------------------------------------------------------------------
1068 // PyOpView
1069 //------------------------------------------------------------------------------
1070 
1071 py::object
1072 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1073                        py::list operandList,
1074                        llvm::Optional<py::dict> attributes,
1075                        llvm::Optional<std::vector<PyBlock *>> successors,
1076                        llvm::Optional<int> regions,
1077                        DefaultingPyLocation location, py::object maybeIp) {
1078   PyMlirContextRef context = location->getContext();
1079   // Class level operation construction metadata.
1080   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1081   // Operand and result segment specs are either none, which does no
1082   // variadic unpacking, or a list of ints with segment sizes, where each
1083   // element is either a positive number (typically 1 for a scalar) or -1 to
1084   // indicate that it is derived from the length of the same-indexed operand
1085   // or result (implying that it is a list at that position).
1086   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1087   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1088 
1089   std::vector<uint32_t> operandSegmentLengths;
1090   std::vector<uint32_t> resultSegmentLengths;
1091 
1092   // Validate/determine region count.
1093   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1094   int opMinRegionCount = std::get<0>(opRegionSpec);
1095   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1096   if (!regions) {
1097     regions = opMinRegionCount;
1098   }
1099   if (*regions < opMinRegionCount) {
1100     throw py::value_error(
1101         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1102          llvm::Twine(opMinRegionCount) +
1103          " regions but was built with regions=" + llvm::Twine(*regions))
1104             .str());
1105   }
1106   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1107     throw py::value_error(
1108         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1109          llvm::Twine(opMinRegionCount) +
1110          " regions but was built with regions=" + llvm::Twine(*regions))
1111             .str());
1112   }
1113 
1114   // Unpack results.
1115   std::vector<PyType *> resultTypes;
1116   resultTypes.reserve(resultTypeList.size());
1117   if (resultSegmentSpecObj.is_none()) {
1118     // Non-variadic result unpacking.
1119     for (auto it : llvm::enumerate(resultTypeList)) {
1120       try {
1121         resultTypes.push_back(py::cast<PyType *>(it.value()));
1122         if (!resultTypes.back())
1123           throw py::cast_error();
1124       } catch (py::cast_error &err) {
1125         throw py::value_error((llvm::Twine("Result ") +
1126                                llvm::Twine(it.index()) + " of operation \"" +
1127                                name + "\" must be a Type (" + err.what() + ")")
1128                                   .str());
1129       }
1130     }
1131   } else {
1132     // Sized result unpacking.
1133     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1134     if (resultSegmentSpec.size() != resultTypeList.size()) {
1135       throw py::value_error((llvm::Twine("Operation \"") + name +
1136                              "\" requires " +
1137                              llvm::Twine(resultSegmentSpec.size()) +
1138                              "result segments but was provided " +
1139                              llvm::Twine(resultTypeList.size()))
1140                                 .str());
1141     }
1142     resultSegmentLengths.reserve(resultTypeList.size());
1143     for (auto it :
1144          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1145       int segmentSpec = std::get<1>(it.value());
1146       if (segmentSpec == 1 || segmentSpec == 0) {
1147         // Unpack unary element.
1148         try {
1149           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1150           if (resultType) {
1151             resultTypes.push_back(resultType);
1152             resultSegmentLengths.push_back(1);
1153           } else if (segmentSpec == 0) {
1154             // Allowed to be optional.
1155             resultSegmentLengths.push_back(0);
1156           } else {
1157             throw py::cast_error("was None and result is not optional");
1158           }
1159         } catch (py::cast_error &err) {
1160           throw py::value_error((llvm::Twine("Result ") +
1161                                  llvm::Twine(it.index()) + " of operation \"" +
1162                                  name + "\" must be a Type (" + err.what() +
1163                                  ")")
1164                                     .str());
1165         }
1166       } else if (segmentSpec == -1) {
1167         // Unpack sequence by appending.
1168         try {
1169           if (std::get<0>(it.value()).is_none()) {
1170             // Treat it as an empty list.
1171             resultSegmentLengths.push_back(0);
1172           } else {
1173             // Unpack the list.
1174             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1175             for (py::object segmentItem : segment) {
1176               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1177               if (!resultTypes.back()) {
1178                 throw py::cast_error("contained a None item");
1179               }
1180             }
1181             resultSegmentLengths.push_back(segment.size());
1182           }
1183         } catch (std::exception &err) {
1184           // NOTE: Sloppy to be using a catch-all here, but there are at least
1185           // three different unrelated exceptions that can be thrown in the
1186           // above "casts". Just keep the scope above small and catch them all.
1187           throw py::value_error((llvm::Twine("Result ") +
1188                                  llvm::Twine(it.index()) + " of operation \"" +
1189                                  name + "\" must be a Sequence of Types (" +
1190                                  err.what() + ")")
1191                                     .str());
1192         }
1193       } else {
1194         throw py::value_error("Unexpected segment spec");
1195       }
1196     }
1197   }
1198 
1199   // Unpack operands.
1200   std::vector<PyValue *> operands;
1201   operands.reserve(operands.size());
1202   if (operandSegmentSpecObj.is_none()) {
1203     // Non-sized operand unpacking.
1204     for (auto it : llvm::enumerate(operandList)) {
1205       try {
1206         operands.push_back(py::cast<PyValue *>(it.value()));
1207         if (!operands.back())
1208           throw py::cast_error();
1209       } catch (py::cast_error &err) {
1210         throw py::value_error((llvm::Twine("Operand ") +
1211                                llvm::Twine(it.index()) + " of operation \"" +
1212                                name + "\" must be a Value (" + err.what() + ")")
1213                                   .str());
1214       }
1215     }
1216   } else {
1217     // Sized operand unpacking.
1218     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1219     if (operandSegmentSpec.size() != operandList.size()) {
1220       throw py::value_error((llvm::Twine("Operation \"") + name +
1221                              "\" requires " +
1222                              llvm::Twine(operandSegmentSpec.size()) +
1223                              "operand segments but was provided " +
1224                              llvm::Twine(operandList.size()))
1225                                 .str());
1226     }
1227     operandSegmentLengths.reserve(operandList.size());
1228     for (auto it :
1229          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1230       int segmentSpec = std::get<1>(it.value());
1231       if (segmentSpec == 1 || segmentSpec == 0) {
1232         // Unpack unary element.
1233         try {
1234           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1235           if (operandValue) {
1236             operands.push_back(operandValue);
1237             operandSegmentLengths.push_back(1);
1238           } else if (segmentSpec == 0) {
1239             // Allowed to be optional.
1240             operandSegmentLengths.push_back(0);
1241           } else {
1242             throw py::cast_error("was None and operand is not optional");
1243           }
1244         } catch (py::cast_error &err) {
1245           throw py::value_error((llvm::Twine("Operand ") +
1246                                  llvm::Twine(it.index()) + " of operation \"" +
1247                                  name + "\" must be a Value (" + err.what() +
1248                                  ")")
1249                                     .str());
1250         }
1251       } else if (segmentSpec == -1) {
1252         // Unpack sequence by appending.
1253         try {
1254           if (std::get<0>(it.value()).is_none()) {
1255             // Treat it as an empty list.
1256             operandSegmentLengths.push_back(0);
1257           } else {
1258             // Unpack the list.
1259             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1260             for (py::object segmentItem : segment) {
1261               operands.push_back(py::cast<PyValue *>(segmentItem));
1262               if (!operands.back()) {
1263                 throw py::cast_error("contained a None item");
1264               }
1265             }
1266             operandSegmentLengths.push_back(segment.size());
1267           }
1268         } catch (std::exception &err) {
1269           // NOTE: Sloppy to be using a catch-all here, but there are at least
1270           // three different unrelated exceptions that can be thrown in the
1271           // above "casts". Just keep the scope above small and catch them all.
1272           throw py::value_error((llvm::Twine("Operand ") +
1273                                  llvm::Twine(it.index()) + " of operation \"" +
1274                                  name + "\" must be a Sequence of Values (" +
1275                                  err.what() + ")")
1276                                     .str());
1277         }
1278       } else {
1279         throw py::value_error("Unexpected segment spec");
1280       }
1281     }
1282   }
1283 
1284   // Merge operand/result segment lengths into attributes if needed.
1285   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1286     // Dup.
1287     if (attributes) {
1288       attributes = py::dict(*attributes);
1289     } else {
1290       attributes = py::dict();
1291     }
1292     if (attributes->contains("result_segment_sizes") ||
1293         attributes->contains("operand_segment_sizes")) {
1294       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1295                             "'operand_segment_sizes' attribute is unsupported. "
1296                             "Use Operation.create for such low-level access.");
1297     }
1298 
1299     // Add result_segment_sizes attribute.
1300     if (!resultSegmentLengths.empty()) {
1301       int64_t size = resultSegmentLengths.size();
1302       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1303           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1304           resultSegmentLengths.size(), resultSegmentLengths.data());
1305       (*attributes)["result_segment_sizes"] =
1306           PyAttribute(context, segmentLengthAttr);
1307     }
1308 
1309     // Add operand_segment_sizes attribute.
1310     if (!operandSegmentLengths.empty()) {
1311       int64_t size = operandSegmentLengths.size();
1312       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1313           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1314           operandSegmentLengths.size(), operandSegmentLengths.data());
1315       (*attributes)["operand_segment_sizes"] =
1316           PyAttribute(context, segmentLengthAttr);
1317     }
1318   }
1319 
1320   // Delegate to create.
1321   return PyOperation::create(std::move(name),
1322                              /*results=*/std::move(resultTypes),
1323                              /*operands=*/std::move(operands),
1324                              /*attributes=*/std::move(attributes),
1325                              /*successors=*/std::move(successors),
1326                              /*regions=*/*regions, location, maybeIp);
1327 }
1328 
/// Wraps an existing operation as an OpView. `operationObject` may be any
/// PyOperationBase subclass (an Operation or another OpView).
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      // Store the operation's own Python object (not the passed-in object,
      // which may be an OpView). NOTE(review): this initializer reads
      // `operation`, so it relies on `operation` being declared before
      // `operationObject` in the class definition.
      operationObject(operation.getRef().getObject()) {}
1334 
1335 py::object PyOpView::createRawSubclass(py::object userClass) {
1336   // This is... a little gross. The typical pattern is to have a pure python
1337   // class that extends OpView like:
1338   //   class AddFOp(_cext.ir.OpView):
1339   //     def __init__(self, loc, lhs, rhs):
1340   //       operation = loc.context.create_operation(
1341   //           "addf", lhs, rhs, results=[lhs.type])
1342   //       super().__init__(operation)
1343   //
1344   // I.e. The goal of the user facing type is to provide a nice constructor
1345   // that has complete freedom for the op under construction. This is at odds
1346   // with our other desire to sometimes create this object by just passing an
1347   // operation (to initialize the base class). We could do *arg and **kwargs
1348   // munging to try to make it work, but instead, we synthesize a new class
1349   // on the fly which extends this user class (AddFOp in this example) and
1350   // *give it* the base class's __init__ method, thus bypassing the
1351   // intermediate subclass's __init__ method entirely. While slightly,
1352   // underhanded, this is safe/legal because the type hierarchy has not changed
1353   // (we just added a new leaf) and we aren't mucking around with __new__.
1354   // Typically, this new class will be stored on the original as "_Raw" and will
1355   // be used for casts and other things that need a variant of the class that
1356   // is initialized purely from an operation.
1357   py::object parentMetaclass =
1358       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1359   py::dict attributes;
1360   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1361   // now.
1362   //   auto opViewType = py::type::of<PyOpView>();
1363   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1364   attributes["__init__"] = opViewType.attr("__init__");
1365   py::str origName = userClass.attr("__name__");
1366   py::str newName = py::str("_") + origName;
1367   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1368 }
1369 
1370 //------------------------------------------------------------------------------
1371 // PyInsertionPoint.
1372 //------------------------------------------------------------------------------
1373 
1374 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1375 
1376 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1377     : refOperation(beforeOperationBase.getOperation().getRef()),
1378       block((*refOperation)->getBlock()) {}
1379 
1380 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1381   PyOperation &operation = operationBase.getOperation();
1382   if (operation.isAttached())
1383     throw SetPyError(PyExc_ValueError,
1384                      "Attempt to insert operation that is already attached");
1385   block.getParentOperation()->checkValid();
1386   MlirOperation beforeOp = {nullptr};
1387   if (refOperation) {
1388     // Insert before operation.
1389     (*refOperation)->checkValid();
1390     beforeOp = (*refOperation)->get();
1391   } else {
1392     // Insert at end (before null) is only valid if the block does not
1393     // already end in a known terminator (violating this will cause assertion
1394     // failures later).
1395     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1396       throw py::index_error("Cannot insert operation at the end of a block "
1397                             "that already has a terminator. Did you mean to "
1398                             "use 'InsertionPoint.at_block_terminator(block)' "
1399                             "versus 'InsertionPoint(block)'?");
1400     }
1401   }
1402   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1403   operation.setAttached();
1404 }
1405 
1406 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1407   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1408   if (mlirOperationIsNull(firstOp)) {
1409     // Just insert at end.
1410     return PyInsertionPoint(block);
1411   }
1412 
1413   // Insert before first op.
1414   PyOperationRef firstOpRef = PyOperation::forOperation(
1415       block.getParentOperation()->getContext(), firstOp);
1416   return PyInsertionPoint{block, std::move(firstOpRef)};
1417 }
1418 
1419 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1420   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1421   if (mlirOperationIsNull(terminator))
1422     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1423   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1424       block.getParentOperation()->getContext(), terminator);
1425   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1426 }
1427 
1428 py::object PyInsertionPoint::contextEnter() {
1429   return PyThreadContextEntry::pushInsertionPoint(*this);
1430 }
1431 
1432 void PyInsertionPoint::contextExit(pybind11::object excType,
1433                                    pybind11::object excVal,
1434                                    pybind11::object excTb) {
1435   PyThreadContextEntry::popInsertionPoint(*this);
1436 }
1437 
1438 //------------------------------------------------------------------------------
1439 // PyAttribute.
1440 //------------------------------------------------------------------------------
1441 
1442 bool PyAttribute::operator==(const PyAttribute &other) {
1443   return mlirAttributeEqual(attr, other.attr);
1444 }
1445 
1446 py::object PyAttribute::getCapsule() {
1447   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1448 }
1449 
1450 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1451   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1452   if (mlirAttributeIsNull(rawAttr))
1453     throw py::error_already_set();
1454   return PyAttribute(
1455       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1456 }
1457 
1458 //------------------------------------------------------------------------------
1459 // PyNamedAttribute.
1460 //------------------------------------------------------------------------------
1461 
1462 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1463     : ownedName(new std::string(std::move(ownedName))) {
1464   namedAttr = mlirNamedAttributeGet(
1465       mlirIdentifierGet(mlirAttributeGetContext(attr),
1466                         toMlirStringRef(*this->ownedName)),
1467       attr);
1468 }
1469 
1470 //------------------------------------------------------------------------------
1471 // PyType.
1472 //------------------------------------------------------------------------------
1473 
1474 bool PyType::operator==(const PyType &other) {
1475   return mlirTypeEqual(type, other.type);
1476 }
1477 
1478 py::object PyType::getCapsule() {
1479   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1480 }
1481 
1482 PyType PyType::createFromCapsule(py::object capsule) {
1483   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1484   if (mlirTypeIsNull(rawType))
1485     throw py::error_already_set();
1486   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1487                 rawType);
1488 }
1489 
1490 //------------------------------------------------------------------------------
1491 // PyValue and subclases.
1492 //------------------------------------------------------------------------------
1493 
1494 pybind11::object PyValue::getCapsule() {
1495   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1496 }
1497 
1498 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1499   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1500   if (mlirValueIsNull(value))
1501     throw py::error_already_set();
1502   MlirOperation owner;
1503   if (mlirValueIsAOpResult(value))
1504     owner = mlirOpResultGetOwner(value);
1505   if (mlirValueIsABlockArgument(value))
1506     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1507   if (mlirOperationIsNull(owner))
1508     throw py::error_already_set();
1509   MlirContext ctx = mlirOperationGetContext(owner);
1510   PyOperationRef ownerRef =
1511       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1512   return PyValue(ownerRef, value);
1513 }
1514 
1515 namespace {
1516 /// CRTP base class for Python MLIR values that subclass Value and should be
1517 /// castable from it. The value hierarchy is one level deep and is not supposed
1518 /// to accommodate other levels unless core MLIR changes.
1519 template <typename DerivedTy>
1520 class PyConcreteValue : public PyValue {
1521 public:
1522   // Derived classes must define statics for:
1523   //   IsAFunctionTy isaFunction
1524   //   const char *pyClassName
1525   // and redefine bindDerived.
1526   using ClassTy = py::class_<DerivedTy, PyValue>;
1527   using IsAFunctionTy = bool (*)(MlirValue);
1528 
1529   PyConcreteValue() = default;
1530   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1531       : PyValue(operationRef, value) {}
1532   PyConcreteValue(PyValue &orig)
1533       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1534 
1535   /// Attempts to cast the original value to the derived type and throws on
1536   /// type mismatches.
1537   static MlirValue castFrom(PyValue &orig) {
1538     if (!DerivedTy::isaFunction(orig.get())) {
1539       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1540       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1541                                              DerivedTy::pyClassName +
1542                                              " (from " + origRepr + ")");
1543     }
1544     return orig.get();
1545   }
1546 
1547   /// Binds the Python module objects to functions of this class.
1548   static void bind(py::module &m) {
1549     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1550     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1551     cls.def_static("isinstance", [](PyValue &otherValue) -> bool {
1552       return DerivedTy::isaFunction(otherValue);
1553     });
1554     DerivedTy::bindDerived(cls);
1555   }
1556 
1557   /// Implemented by derived classes to add methods to the Python subclass.
1558   static void bindDerived(ClassTy &m) {}
1559 };
1560 
1561 /// Python wrapper for MlirBlockArgument.
1562 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1563 public:
1564   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1565   static constexpr const char *pyClassName = "BlockArgument";
1566   using PyConcreteValue::PyConcreteValue;
1567 
1568   static void bindDerived(ClassTy &c) {
1569     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1570       return PyBlock(self.getParentOperation(),
1571                      mlirBlockArgumentGetOwner(self.get()));
1572     });
1573     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1574       return mlirBlockArgumentGetArgNumber(self.get());
1575     });
1576     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1577       return mlirBlockArgumentSetType(self.get(), type);
1578     });
1579   }
1580 };
1581 
1582 /// Python wrapper for MlirOpResult.
1583 class PyOpResult : public PyConcreteValue<PyOpResult> {
1584 public:
1585   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1586   static constexpr const char *pyClassName = "OpResult";
1587   using PyConcreteValue::PyConcreteValue;
1588 
1589   static void bindDerived(ClassTy &c) {
1590     c.def_property_readonly("owner", [](PyOpResult &self) {
1591       assert(
1592           mlirOperationEqual(self.getParentOperation()->get(),
1593                              mlirOpResultGetOwner(self.get())) &&
1594           "expected the owner of the value in Python to match that in the IR");
1595       return self.getParentOperation().getObject();
1596     });
1597     c.def_property_readonly("result_number", [](PyOpResult &self) {
1598       return mlirOpResultGetResultNumber(self.get());
1599     });
1600   }
1601 };
1602 
1603 /// Returns the list of types of the values held by container.
1604 template <typename Container>
1605 static std::vector<PyType> getValueTypes(Container &container,
1606                                          PyMlirContextRef &context) {
1607   std::vector<PyType> result;
1608   result.reserve(container.getNumElements());
1609   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1610     result.push_back(
1611         PyType(context, mlirValueGetType(container.getElement(i).get())));
1612   }
1613   return result;
1614 }
1615 
1616 /// A list of block arguments. Internally, these are stored as consecutive
1617 /// elements, random access is cheap. The argument list is associated with the
1618 /// operation that contains the block (detached blocks are not allowed in
1619 /// Python bindings) and extends its lifetime.
1620 class PyBlockArgumentList
1621     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1622 public:
1623   static constexpr const char *pyClassName = "BlockArgumentList";
1624 
1625   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1626                       intptr_t startIndex = 0, intptr_t length = -1,
1627                       intptr_t step = 1)
1628       : Sliceable(startIndex,
1629                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1630                   step),
1631         operation(std::move(operation)), block(block) {}
1632 
1633   /// Returns the number of arguments in the list.
1634   intptr_t getNumElements() {
1635     operation->checkValid();
1636     return mlirBlockGetNumArguments(block);
1637   }
1638 
1639   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1640   PyBlockArgument getElement(intptr_t pos) {
1641     MlirValue argument = mlirBlockGetArgument(block, pos);
1642     return PyBlockArgument(operation, argument);
1643   }
1644 
1645   /// Returns a sublist of this list.
1646   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1647                             intptr_t step) {
1648     return PyBlockArgumentList(operation, block, startIndex, length, step);
1649   }
1650 
1651   static void bindDerived(ClassTy &c) {
1652     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1653       return getValueTypes(self, self.operation->getContext());
1654     });
1655   }
1656 
1657 private:
1658   PyOperationRef operation;
1659   MlirBlock block;
1660 };
1661 
1662 /// A list of operation operands. Internally, these are stored as consecutive
1663 /// elements, random access is cheap. The result list is associated with the
1664 /// operation whose results these are, and extends the lifetime of this
1665 /// operation.
1666 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1667 public:
1668   static constexpr const char *pyClassName = "OpOperandList";
1669 
1670   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1671                   intptr_t length = -1, intptr_t step = 1)
1672       : Sliceable(startIndex,
1673                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1674                                : length,
1675                   step),
1676         operation(operation) {}
1677 
1678   intptr_t getNumElements() {
1679     operation->checkValid();
1680     return mlirOperationGetNumOperands(operation->get());
1681   }
1682 
1683   PyValue getElement(intptr_t pos) {
1684     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1685     MlirOperation owner;
1686     if (mlirValueIsAOpResult(operand))
1687       owner = mlirOpResultGetOwner(operand);
1688     else if (mlirValueIsABlockArgument(operand))
1689       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1690     else
1691       assert(false && "Value must be an block arg or op result.");
1692     PyOperationRef pyOwner =
1693         PyOperation::forOperation(operation->getContext(), owner);
1694     return PyValue(pyOwner, operand);
1695   }
1696 
1697   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1698     return PyOpOperandList(operation, startIndex, length, step);
1699   }
1700 
1701   void dunderSetItem(intptr_t index, PyValue value) {
1702     index = wrapIndex(index);
1703     mlirOperationSetOperand(operation->get(), index, value.get());
1704   }
1705 
1706   static void bindDerived(ClassTy &c) {
1707     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1708   }
1709 
1710 private:
1711   PyOperationRef operation;
1712 };
1713 
1714 /// A list of operation results. Internally, these are stored as consecutive
1715 /// elements, random access is cheap. The result list is associated with the
1716 /// operation whose results these are, and extends the lifetime of this
1717 /// operation.
1718 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1719 public:
1720   static constexpr const char *pyClassName = "OpResultList";
1721 
1722   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1723                  intptr_t length = -1, intptr_t step = 1)
1724       : Sliceable(startIndex,
1725                   length == -1 ? mlirOperationGetNumResults(operation->get())
1726                                : length,
1727                   step),
1728         operation(operation) {}
1729 
1730   intptr_t getNumElements() {
1731     operation->checkValid();
1732     return mlirOperationGetNumResults(operation->get());
1733   }
1734 
1735   PyOpResult getElement(intptr_t index) {
1736     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1737     return PyOpResult(value);
1738   }
1739 
1740   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1741     return PyOpResultList(operation, startIndex, length, step);
1742   }
1743 
1744   static void bindDerived(ClassTy &c) {
1745     c.def_property_readonly("types", [](PyOpResultList &self) {
1746       return getValueTypes(self, self.operation->getContext());
1747     });
1748   }
1749 
1750 private:
1751   PyOperationRef operation;
1752 };
1753 
1754 /// A list of operation attributes. Can be indexed by name, producing
1755 /// attributes, or by index, producing named attributes.
1756 class PyOpAttributeMap {
1757 public:
1758   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1759 
1760   PyAttribute dunderGetItemNamed(const std::string &name) {
1761     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1762                                                          toMlirStringRef(name));
1763     if (mlirAttributeIsNull(attr)) {
1764       throw SetPyError(PyExc_KeyError,
1765                        "attempt to access a non-existent attribute");
1766     }
1767     return PyAttribute(operation->getContext(), attr);
1768   }
1769 
1770   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1771     if (index < 0 || index >= dunderLen()) {
1772       throw SetPyError(PyExc_IndexError,
1773                        "attempt to access out of bounds attribute");
1774     }
1775     MlirNamedAttribute namedAttr =
1776         mlirOperationGetAttribute(operation->get(), index);
1777     return PyNamedAttribute(
1778         namedAttr.attribute,
1779         std::string(mlirIdentifierStr(namedAttr.name).data));
1780   }
1781 
1782   void dunderSetItem(const std::string &name, PyAttribute attr) {
1783     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1784                                     attr);
1785   }
1786 
1787   void dunderDelItem(const std::string &name) {
1788     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1789                                                      toMlirStringRef(name));
1790     if (!removed)
1791       throw SetPyError(PyExc_KeyError,
1792                        "attempt to delete a non-existent attribute");
1793   }
1794 
1795   intptr_t dunderLen() {
1796     return mlirOperationGetNumAttributes(operation->get());
1797   }
1798 
1799   bool dunderContains(const std::string &name) {
1800     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1801         operation->get(), toMlirStringRef(name)));
1802   }
1803 
1804   static void bind(py::module &m) {
1805     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
1806         .def("__contains__", &PyOpAttributeMap::dunderContains)
1807         .def("__len__", &PyOpAttributeMap::dunderLen)
1808         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1809         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1810         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1811         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1812   }
1813 
1814 private:
1815   PyOperationRef operation;
1816 };
1817 
1818 } // end namespace
1819 
1820 //------------------------------------------------------------------------------
1821 // Populates the core exports of the 'ir' submodule.
1822 //------------------------------------------------------------------------------
1823 
1824 void mlir::python::populateIRCore(py::module &m) {
1825   //----------------------------------------------------------------------------
1826   // Mapping of MlirContext.
1827   //----------------------------------------------------------------------------
1828   py::class_<PyMlirContext>(m, "Context", py::module_local())
1829       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1830       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1831       .def("_get_context_again",
1832            [](PyMlirContext &self) {
1833              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1834              return ref.releaseObject();
1835            })
1836       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1837       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1838       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1839                              &PyMlirContext::getCapsule)
1840       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1841       .def("__enter__", &PyMlirContext::contextEnter)
1842       .def("__exit__", &PyMlirContext::contextExit)
1843       .def_property_readonly_static(
1844           "current",
1845           [](py::object & /*class*/) {
1846             auto *context = PyThreadContextEntry::getDefaultContext();
1847             if (!context)
1848               throw SetPyError(PyExc_ValueError, "No current Context");
1849             return context;
1850           },
1851           "Gets the Context bound to the current thread or raises ValueError")
1852       .def_property_readonly(
1853           "dialects",
1854           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1855           "Gets a container for accessing dialects by name")
1856       .def_property_readonly(
1857           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1858           "Alias for 'dialect'")
1859       .def(
1860           "get_dialect_descriptor",
1861           [=](PyMlirContext &self, std::string &name) {
1862             MlirDialect dialect = mlirContextGetOrLoadDialect(
1863                 self.get(), {name.data(), name.size()});
1864             if (mlirDialectIsNull(dialect)) {
1865               throw SetPyError(PyExc_ValueError,
1866                                Twine("Dialect '") + name + "' not found");
1867             }
1868             return PyDialectDescriptor(self.getRef(), dialect);
1869           },
1870           "Gets or loads a dialect by name, returning its descriptor object")
1871       .def_property(
1872           "allow_unregistered_dialects",
1873           [](PyMlirContext &self) -> bool {
1874             return mlirContextGetAllowUnregisteredDialects(self.get());
1875           },
1876           [](PyMlirContext &self, bool value) {
1877             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1878           })
1879       .def("enable_multithreading",
1880            [](PyMlirContext &self, bool enable) {
1881              mlirContextEnableMultithreading(self.get(), enable);
1882            })
1883       .def("is_registered_operation",
1884            [](PyMlirContext &self, std::string &name) {
1885              return mlirContextIsRegisteredOperation(
1886                  self.get(), MlirStringRef{name.data(), name.size()});
1887            });
1888 
1889   //----------------------------------------------------------------------------
1890   // Mapping of PyDialectDescriptor
1891   //----------------------------------------------------------------------------
1892   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
1893       .def_property_readonly("namespace",
1894                              [](PyDialectDescriptor &self) {
1895                                MlirStringRef ns =
1896                                    mlirDialectGetNamespace(self.get());
1897                                return py::str(ns.data, ns.length);
1898                              })
1899       .def("__repr__", [](PyDialectDescriptor &self) {
1900         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1901         std::string repr("<DialectDescriptor ");
1902         repr.append(ns.data, ns.length);
1903         repr.append(">");
1904         return repr;
1905       });
1906 
1907   //----------------------------------------------------------------------------
1908   // Mapping of PyDialects
1909   //----------------------------------------------------------------------------
1910   py::class_<PyDialects>(m, "Dialects", py::module_local())
1911       .def("__getitem__",
1912            [=](PyDialects &self, std::string keyName) {
1913              MlirDialect dialect =
1914                  self.getDialectForKey(keyName, /*attrError=*/false);
1915              py::object descriptor =
1916                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1917              return createCustomDialectWrapper(keyName, std::move(descriptor));
1918            })
1919       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1920         MlirDialect dialect =
1921             self.getDialectForKey(attrName, /*attrError=*/true);
1922         py::object descriptor =
1923             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1924         return createCustomDialectWrapper(attrName, std::move(descriptor));
1925       });
1926 
1927   //----------------------------------------------------------------------------
1928   // Mapping of PyDialect
1929   //----------------------------------------------------------------------------
1930   py::class_<PyDialect>(m, "Dialect", py::module_local())
1931       .def(py::init<py::object>(), "descriptor")
1932       .def_property_readonly(
1933           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1934       .def("__repr__", [](py::object self) {
1935         auto clazz = self.attr("__class__");
1936         return py::str("<Dialect ") +
1937                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1938                clazz.attr("__module__") + py::str(".") +
1939                clazz.attr("__name__") + py::str(")>");
1940       });
1941 
1942   //----------------------------------------------------------------------------
1943   // Mapping of Location
1944   //----------------------------------------------------------------------------
1945   py::class_<PyLocation>(m, "Location", py::module_local())
1946       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
1947       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
1948       .def("__enter__", &PyLocation::contextEnter)
1949       .def("__exit__", &PyLocation::contextExit)
1950       .def("__eq__",
1951            [](PyLocation &self, PyLocation &other) -> bool {
1952              return mlirLocationEqual(self, other);
1953            })
1954       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
1955       .def_property_readonly_static(
1956           "current",
1957           [](py::object & /*class*/) {
1958             auto *loc = PyThreadContextEntry::getDefaultLocation();
1959             if (!loc)
1960               throw SetPyError(PyExc_ValueError, "No current Location");
1961             return loc;
1962           },
1963           "Gets the Location bound to the current thread or raises ValueError")
1964       .def_static(
1965           "unknown",
1966           [](DefaultingPyMlirContext context) {
1967             return PyLocation(context->getRef(),
1968                               mlirLocationUnknownGet(context->get()));
1969           },
1970           py::arg("context") = py::none(),
1971           "Gets a Location representing an unknown location")
1972       .def_static(
1973           "callsite",
1974           [](PyLocation callee, const std::vector<PyLocation> &frames,
1975              DefaultingPyMlirContext context) {
1976             if (frames.empty())
1977               throw py::value_error("No caller frames provided");
1978             MlirLocation caller = frames.back().get();
1979             for (PyLocation frame :
1980                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
1981               caller = mlirLocationCallSiteGet(frame.get(), caller);
1982             return PyLocation(context->getRef(),
1983                               mlirLocationCallSiteGet(callee.get(), caller));
1984           },
1985           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
1986           kContextGetCallSiteLocationDocstring)
1987       .def_static(
1988           "file",
1989           [](std::string filename, int line, int col,
1990              DefaultingPyMlirContext context) {
1991             return PyLocation(
1992                 context->getRef(),
1993                 mlirLocationFileLineColGet(
1994                     context->get(), toMlirStringRef(filename), line, col));
1995           },
1996           py::arg("filename"), py::arg("line"), py::arg("col"),
1997           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
1998       .def_static(
1999           "name",
2000           [](std::string name, llvm::Optional<PyLocation> childLoc,
2001              DefaultingPyMlirContext context) {
2002             return PyLocation(
2003                 context->getRef(),
2004                 mlirLocationNameGet(
2005                     context->get(), toMlirStringRef(name),
2006                     childLoc ? childLoc->get()
2007                              : mlirLocationUnknownGet(context->get())));
2008           },
2009           py::arg("name"), py::arg("childLoc") = py::none(),
2010           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2011       .def_property_readonly(
2012           "context",
2013           [](PyLocation &self) { return self.getContext().getObject(); },
2014           "Context that owns the Location")
2015       .def("__repr__", [](PyLocation &self) {
2016         PyPrintAccumulator printAccum;
2017         mlirLocationPrint(self, printAccum.getCallback(),
2018                           printAccum.getUserData());
2019         return printAccum.join();
2020       });
2021 
2022   //----------------------------------------------------------------------------
2023   // Mapping of Module
2024   //----------------------------------------------------------------------------
2025   py::class_<PyModule>(m, "Module", py::module_local())
2026       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2027       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2028       .def_static(
2029           "parse",
2030           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2031             MlirModule module = mlirModuleCreateParse(
2032                 context->get(), toMlirStringRef(moduleAsm));
2033             // TODO: Rework error reporting once diagnostic engine is exposed
2034             // in C API.
2035             if (mlirModuleIsNull(module)) {
2036               throw SetPyError(
2037                   PyExc_ValueError,
2038                   "Unable to parse module assembly (see diagnostics)");
2039             }
2040             return PyModule::forModule(module).releaseObject();
2041           },
2042           py::arg("asm"), py::arg("context") = py::none(),
2043           kModuleParseDocstring)
2044       .def_static(
2045           "create",
2046           [](DefaultingPyLocation loc) {
2047             MlirModule module = mlirModuleCreateEmpty(loc);
2048             return PyModule::forModule(module).releaseObject();
2049           },
2050           py::arg("loc") = py::none(), "Creates an empty module")
2051       .def_property_readonly(
2052           "context",
2053           [](PyModule &self) { return self.getContext().getObject(); },
2054           "Context that created the Module")
2055       .def_property_readonly(
2056           "operation",
2057           [](PyModule &self) {
2058             return PyOperation::forOperation(self.getContext(),
2059                                              mlirModuleGetOperation(self.get()),
2060                                              self.getRef().releaseObject())
2061                 .releaseObject();
2062           },
2063           "Accesses the module as an operation")
2064       .def_property_readonly(
2065           "body",
2066           [](PyModule &self) {
2067             PyOperationRef module_op = PyOperation::forOperation(
2068                 self.getContext(), mlirModuleGetOperation(self.get()),
2069                 self.getRef().releaseObject());
2070             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2071             return returnBlock;
2072           },
2073           "Return the block for this module")
2074       .def(
2075           "dump",
2076           [](PyModule &self) {
2077             mlirOperationDump(mlirModuleGetOperation(self.get()));
2078           },
2079           kDumpDocstring)
2080       .def(
2081           "__str__",
2082           [](PyModule &self) {
2083             MlirOperation operation = mlirModuleGetOperation(self.get());
2084             PyPrintAccumulator printAccum;
2085             mlirOperationPrint(operation, printAccum.getCallback(),
2086                                printAccum.getUserData());
2087             return printAccum.join();
2088           },
2089           kOperationStrDunderDocstring);
2090 
2091   //----------------------------------------------------------------------------
2092   // Mapping of Operation.
2093   //----------------------------------------------------------------------------
2094   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2095       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2096                              [](PyOperationBase &self) {
2097                                return self.getOperation().getCapsule();
2098                              })
2099       .def("__eq__",
2100            [](PyOperationBase &self, PyOperationBase &other) {
2101              return &self.getOperation() == &other.getOperation();
2102            })
2103       .def("__eq__",
2104            [](PyOperationBase &self, py::object other) { return false; })
2105       .def_property_readonly("attributes",
2106                              [](PyOperationBase &self) {
2107                                return PyOpAttributeMap(
2108                                    self.getOperation().getRef());
2109                              })
2110       .def_property_readonly("operands",
2111                              [](PyOperationBase &self) {
2112                                return PyOpOperandList(
2113                                    self.getOperation().getRef());
2114                              })
2115       .def_property_readonly("regions",
2116                              [](PyOperationBase &self) {
2117                                return PyRegionList(
2118                                    self.getOperation().getRef());
2119                              })
2120       .def_property_readonly(
2121           "results",
2122           [](PyOperationBase &self) {
2123             return PyOpResultList(self.getOperation().getRef());
2124           },
2125           "Returns the list of Operation results.")
2126       .def_property_readonly(
2127           "result",
2128           [](PyOperationBase &self) {
2129             auto &operation = self.getOperation();
2130             auto numResults = mlirOperationGetNumResults(operation);
2131             if (numResults != 1) {
2132               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2133               throw SetPyError(
2134                   PyExc_ValueError,
2135                   Twine("Cannot call .result on operation ") +
2136                       StringRef(name.data, name.length) + " which has " +
2137                       Twine(numResults) +
2138                       " results (it is only valid for operations with a "
2139                       "single result)");
2140             }
2141             return PyOpResult(operation.getRef(),
2142                               mlirOperationGetResult(operation, 0));
2143           },
2144           "Shortcut to get an op result if it has only one (throws an error "
2145           "otherwise).")
2146       .def("__iter__",
2147            [](PyOperationBase &self) {
2148              return PyRegionIterator(self.getOperation().getRef());
2149            })
2150       .def(
2151           "__str__",
2152           [](PyOperationBase &self) {
2153             return self.getAsm(/*binary=*/false,
2154                                /*largeElementsLimit=*/llvm::None,
2155                                /*enableDebugInfo=*/false,
2156                                /*prettyDebugInfo=*/false,
2157                                /*printGenericOpForm=*/false,
2158                                /*useLocalScope=*/false);
2159           },
2160           "Returns the assembly form of the operation.")
2161       .def("print", &PyOperationBase::print,
2162            // Careful: Lots of arguments must match up with print method.
2163            py::arg("file") = py::none(), py::arg("binary") = false,
2164            py::arg("large_elements_limit") = py::none(),
2165            py::arg("enable_debug_info") = false,
2166            py::arg("pretty_debug_info") = false,
2167            py::arg("print_generic_op_form") = false,
2168            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2169       .def("get_asm", &PyOperationBase::getAsm,
2170            // Careful: Lots of arguments must match up with get_asm method.
2171            py::arg("binary") = false,
2172            py::arg("large_elements_limit") = py::none(),
2173            py::arg("enable_debug_info") = false,
2174            py::arg("pretty_debug_info") = false,
2175            py::arg("print_generic_op_form") = false,
2176            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2177       .def(
2178           "verify",
2179           [](PyOperationBase &self) {
2180             return mlirOperationVerify(self.getOperation());
2181           },
2182           "Verify the operation and return true if it passes, false if it "
2183           "fails.");
2184 
2185   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2186       .def_static("create", &PyOperation::create, py::arg("name"),
2187                   py::arg("results") = py::none(),
2188                   py::arg("operands") = py::none(),
2189                   py::arg("attributes") = py::none(),
2190                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2191                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2192                   kOperationCreateDocstring)
2193       .def_property_readonly("parent",
2194                              [](PyOperation &self) -> py::object {
2195                                auto parent = self.getParentOperation();
2196                                if (parent)
2197                                  return parent->getObject();
2198                                return py::none();
2199                              })
2200       .def("erase", &PyOperation::erase)
2201       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2202                              &PyOperation::getCapsule)
2203       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2204       .def_property_readonly("name",
2205                              [](PyOperation &self) {
2206                                self.checkValid();
2207                                MlirOperation operation = self.get();
2208                                MlirStringRef name = mlirIdentifierStr(
2209                                    mlirOperationGetName(operation));
2210                                return py::str(name.data, name.length);
2211                              })
2212       .def_property_readonly(
2213           "context",
2214           [](PyOperation &self) {
2215             self.checkValid();
2216             return self.getContext().getObject();
2217           },
2218           "Context that owns the Operation")
2219       .def_property_readonly("opview", &PyOperation::createOpView);
2220 
2221   auto opViewClass =
2222       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2223           .def(py::init<py::object>())
2224           .def_property_readonly("operation", &PyOpView::getOperationObject)
2225           .def_property_readonly(
2226               "context",
2227               [](PyOpView &self) {
2228                 return self.getOperation().getContext().getObject();
2229               },
2230               "Context that owns the Operation")
2231           .def("__str__", [](PyOpView &self) {
2232             return py::str(self.getOperationObject());
2233           });
2234   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2235   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2236   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2237   opViewClass.attr("build_generic") = classmethod(
2238       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2239       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2240       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2241       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2242       "Builds a specific, generated OpView based on class level attributes.");
2243 
2244   //----------------------------------------------------------------------------
2245   // Mapping of PyRegion.
2246   //----------------------------------------------------------------------------
2247   py::class_<PyRegion>(m, "Region", py::module_local())
2248       .def_property_readonly(
2249           "blocks",
2250           [](PyRegion &self) {
2251             return PyBlockList(self.getParentOperation(), self.get());
2252           },
2253           "Returns a forward-optimized sequence of blocks.")
2254       .def_property_readonly(
2255           "owner",
2256           [](PyRegion &self) {
2257             return self.getParentOperation()->createOpView();
2258           },
2259           "Returns the operation owning this region.")
2260       .def(
2261           "__iter__",
2262           [](PyRegion &self) {
2263             self.checkValid();
2264             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2265             return PyBlockIterator(self.getParentOperation(), firstBlock);
2266           },
2267           "Iterates over blocks in the region.")
2268       .def("__eq__",
2269            [](PyRegion &self, PyRegion &other) {
2270              return self.get().ptr == other.get().ptr;
2271            })
2272       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2273 
2274   //----------------------------------------------------------------------------
2275   // Mapping of PyBlock.
2276   //----------------------------------------------------------------------------
2277   py::class_<PyBlock>(m, "Block", py::module_local())
2278       .def_property_readonly(
2279           "owner",
2280           [](PyBlock &self) {
2281             return self.getParentOperation()->createOpView();
2282           },
2283           "Returns the owning operation of this block.")
2284       .def_property_readonly(
2285           "region",
2286           [](PyBlock &self) {
2287             MlirRegion region = mlirBlockGetParentRegion(self.get());
2288             return PyRegion(self.getParentOperation(), region);
2289           },
2290           "Returns the owning region of this block.")
2291       .def_property_readonly(
2292           "arguments",
2293           [](PyBlock &self) {
2294             return PyBlockArgumentList(self.getParentOperation(), self.get());
2295           },
2296           "Returns a list of block arguments.")
2297       .def_property_readonly(
2298           "operations",
2299           [](PyBlock &self) {
2300             return PyOperationList(self.getParentOperation(), self.get());
2301           },
2302           "Returns a forward-optimized sequence of operations.")
2303       .def_static(
2304           "create_at_start",
2305           [](PyRegion &parent, py::list pyArgTypes) {
2306             parent.checkValid();
2307             llvm::SmallVector<MlirType, 4> argTypes;
2308             argTypes.reserve(pyArgTypes.size());
2309             for (auto &pyArg : pyArgTypes) {
2310               argTypes.push_back(pyArg.cast<PyType &>());
2311             }
2312 
2313             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2314             mlirRegionInsertOwnedBlock(parent, 0, block);
2315             return PyBlock(parent.getParentOperation(), block);
2316           },
2317           py::arg("parent"), py::arg("pyArgTypes") = py::list(),
2318           "Creates and returns a new Block at the beginning of the given "
2319           "region (with given argument types).")
2320       .def(
2321           "create_before",
2322           [](PyBlock &self, py::args pyArgTypes) {
2323             self.checkValid();
2324             llvm::SmallVector<MlirType, 4> argTypes;
2325             argTypes.reserve(pyArgTypes.size());
2326             for (auto &pyArg : pyArgTypes) {
2327               argTypes.push_back(pyArg.cast<PyType &>());
2328             }
2329 
2330             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2331             MlirRegion region = mlirBlockGetParentRegion(self.get());
2332             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2333             return PyBlock(self.getParentOperation(), block);
2334           },
2335           "Creates and returns a new Block before this block "
2336           "(with given argument types).")
2337       .def(
2338           "create_after",
2339           [](PyBlock &self, py::args pyArgTypes) {
2340             self.checkValid();
2341             llvm::SmallVector<MlirType, 4> argTypes;
2342             argTypes.reserve(pyArgTypes.size());
2343             for (auto &pyArg : pyArgTypes) {
2344               argTypes.push_back(pyArg.cast<PyType &>());
2345             }
2346 
2347             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2348             MlirRegion region = mlirBlockGetParentRegion(self.get());
2349             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2350             return PyBlock(self.getParentOperation(), block);
2351           },
2352           "Creates and returns a new Block after this block "
2353           "(with given argument types).")
2354       .def(
2355           "__iter__",
2356           [](PyBlock &self) {
2357             self.checkValid();
2358             MlirOperation firstOperation =
2359                 mlirBlockGetFirstOperation(self.get());
2360             return PyOperationIterator(self.getParentOperation(),
2361                                        firstOperation);
2362           },
2363           "Iterates over operations in the block.")
2364       .def("__eq__",
2365            [](PyBlock &self, PyBlock &other) {
2366              return self.get().ptr == other.get().ptr;
2367            })
2368       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2369       .def(
2370           "__str__",
2371           [](PyBlock &self) {
2372             self.checkValid();
2373             PyPrintAccumulator printAccum;
2374             mlirBlockPrint(self.get(), printAccum.getCallback(),
2375                            printAccum.getUserData());
2376             return printAccum.join();
2377           },
2378           "Returns the assembly form of the block.");
2379 
2380   //----------------------------------------------------------------------------
2381   // Mapping of PyInsertionPoint.
2382   //----------------------------------------------------------------------------
2383 
2384   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2385       .def(py::init<PyBlock &>(), py::arg("block"),
2386            "Inserts after the last operation but still inside the block.")
2387       .def("__enter__", &PyInsertionPoint::contextEnter)
2388       .def("__exit__", &PyInsertionPoint::contextExit)
2389       .def_property_readonly_static(
2390           "current",
2391           [](py::object & /*class*/) {
2392             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2393             if (!ip)
2394               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2395             return ip;
2396           },
2397           "Gets the InsertionPoint bound to the current thread or raises "
2398           "ValueError if none has been set")
2399       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2400            "Inserts before a referenced operation.")
2401       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2402                   py::arg("block"), "Inserts at the beginning of the block.")
2403       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2404                   py::arg("block"), "Inserts before the block terminator.")
2405       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2406            "Inserts an operation.")
2407       .def_property_readonly(
2408           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2409           "Returns the block that this InsertionPoint points to.");
2410 
2411   //----------------------------------------------------------------------------
2412   // Mapping of PyAttribute.
2413   //----------------------------------------------------------------------------
2414   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2415       // Delegate to the PyAttribute copy constructor, which will also lifetime
2416       // extend the backing context which owns the MlirAttribute.
2417       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2418            "Casts the passed attribute to the generic Attribute")
2419       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2420                              &PyAttribute::getCapsule)
2421       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2422       .def_static(
2423           "parse",
2424           [](std::string attrSpec, DefaultingPyMlirContext context) {
2425             MlirAttribute type = mlirAttributeParseGet(
2426                 context->get(), toMlirStringRef(attrSpec));
2427             // TODO: Rework error reporting once diagnostic engine is exposed
2428             // in C API.
2429             if (mlirAttributeIsNull(type)) {
2430               throw SetPyError(PyExc_ValueError,
2431                                Twine("Unable to parse attribute: '") +
2432                                    attrSpec + "'");
2433             }
2434             return PyAttribute(context->getRef(), type);
2435           },
2436           py::arg("asm"), py::arg("context") = py::none(),
2437           "Parses an attribute from an assembly form")
2438       .def_property_readonly(
2439           "context",
2440           [](PyAttribute &self) { return self.getContext().getObject(); },
2441           "Context that owns the Attribute")
2442       .def_property_readonly("type",
2443                              [](PyAttribute &self) {
2444                                return PyType(self.getContext()->getRef(),
2445                                              mlirAttributeGetType(self));
2446                              })
2447       .def(
2448           "get_named",
2449           [](PyAttribute &self, std::string name) {
2450             return PyNamedAttribute(self, std::move(name));
2451           },
2452           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2453       .def("__eq__",
2454            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2455       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2456       .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; })
2457       .def(
2458           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2459           kDumpDocstring)
2460       .def(
2461           "__str__",
2462           [](PyAttribute &self) {
2463             PyPrintAccumulator printAccum;
2464             mlirAttributePrint(self, printAccum.getCallback(),
2465                                printAccum.getUserData());
2466             return printAccum.join();
2467           },
2468           "Returns the assembly form of the Attribute.")
2469       .def("__repr__", [](PyAttribute &self) {
2470         // Generally, assembly formats are not printed for __repr__ because
2471         // this can cause exceptionally long debug output and exceptions.
2472         // However, attribute values are generally considered useful and are
2473         // printed. This may need to be re-evaluated if debug dumps end up
2474         // being excessive.
2475         PyPrintAccumulator printAccum;
2476         printAccum.parts.append("Attribute(");
2477         mlirAttributePrint(self, printAccum.getCallback(),
2478                            printAccum.getUserData());
2479         printAccum.parts.append(")");
2480         return printAccum.join();
2481       });
2482 
2483   //----------------------------------------------------------------------------
2484   // Mapping of PyNamedAttribute
2485   //----------------------------------------------------------------------------
2486   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2487       .def("__repr__",
2488            [](PyNamedAttribute &self) {
2489              PyPrintAccumulator printAccum;
2490              printAccum.parts.append("NamedAttribute(");
2491              printAccum.parts.append(
2492                  mlirIdentifierStr(self.namedAttr.name).data);
2493              printAccum.parts.append("=");
2494              mlirAttributePrint(self.namedAttr.attribute,
2495                                 printAccum.getCallback(),
2496                                 printAccum.getUserData());
2497              printAccum.parts.append(")");
2498              return printAccum.join();
2499            })
2500       .def_property_readonly(
2501           "name",
2502           [](PyNamedAttribute &self) {
2503             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2504                            mlirIdentifierStr(self.namedAttr.name).length);
2505           },
2506           "The name of the NamedAttribute binding")
2507       .def_property_readonly(
2508           "attr",
2509           [](PyNamedAttribute &self) {
2510             // TODO: When named attribute is removed/refactored, also remove
2511             // this constructor (it does an inefficient table lookup).
2512             auto contextRef = PyMlirContext::forContext(
2513                 mlirAttributeGetContext(self.namedAttr.attribute));
2514             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2515           },
2516           py::keep_alive<0, 1>(),
2517           "The underlying generic attribute of the NamedAttribute binding");
2518 
2519   //----------------------------------------------------------------------------
2520   // Mapping of PyType.
2521   //----------------------------------------------------------------------------
2522   py::class_<PyType>(m, "Type", py::module_local())
2523       // Delegate to the PyType copy constructor, which will also lifetime
2524       // extend the backing context which owns the MlirType.
2525       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2526            "Casts the passed type to the generic Type")
2527       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2528       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2529       .def_static(
2530           "parse",
2531           [](std::string typeSpec, DefaultingPyMlirContext context) {
2532             MlirType type =
2533                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2534             // TODO: Rework error reporting once diagnostic engine is exposed
2535             // in C API.
2536             if (mlirTypeIsNull(type)) {
2537               throw SetPyError(PyExc_ValueError,
2538                                Twine("Unable to parse type: '") + typeSpec +
2539                                    "'");
2540             }
2541             return PyType(context->getRef(), type);
2542           },
2543           py::arg("asm"), py::arg("context") = py::none(),
2544           kContextParseTypeDocstring)
2545       .def_property_readonly(
2546           "context", [](PyType &self) { return self.getContext().getObject(); },
2547           "Context that owns the Type")
2548       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2549       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2550       .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; })
2551       .def(
2552           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2553       .def(
2554           "__str__",
2555           [](PyType &self) {
2556             PyPrintAccumulator printAccum;
2557             mlirTypePrint(self, printAccum.getCallback(),
2558                           printAccum.getUserData());
2559             return printAccum.join();
2560           },
2561           "Returns the assembly form of the type.")
2562       .def("__repr__", [](PyType &self) {
2563         // Generally, assembly formats are not printed for __repr__ because
2564         // this can cause exceptionally long debug output and exceptions.
2565         // However, types are an exception as they typically have compact
2566         // assembly forms and printing them is useful.
2567         PyPrintAccumulator printAccum;
2568         printAccum.parts.append("Type(");
2569         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2570         printAccum.parts.append(")");
2571         return printAccum.join();
2572       });
2573 
2574   //----------------------------------------------------------------------------
2575   // Mapping of Value.
2576   //----------------------------------------------------------------------------
2577   py::class_<PyValue>(m, "Value", py::module_local())
2578       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2579       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2580       .def_property_readonly(
2581           "context",
2582           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2583           "Context in which the value lives.")
2584       .def(
2585           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2586           kDumpDocstring)
2587       .def_property_readonly(
2588           "owner",
2589           [](PyValue &self) {
2590             assert(mlirOperationEqual(self.getParentOperation()->get(),
2591                                       mlirOpResultGetOwner(self.get())) &&
2592                    "expected the owner of the value in Python to match that in "
2593                    "the IR");
2594             return self.getParentOperation().getObject();
2595           })
2596       .def("__eq__",
2597            [](PyValue &self, PyValue &other) {
2598              return self.get().ptr == other.get().ptr;
2599            })
2600       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2601       .def(
2602           "__str__",
2603           [](PyValue &self) {
2604             PyPrintAccumulator printAccum;
2605             printAccum.parts.append("Value(");
2606             mlirValuePrint(self.get(), printAccum.getCallback(),
2607                            printAccum.getUserData());
2608             printAccum.parts.append(")");
2609             return printAccum.join();
2610           },
2611           kValueDunderStrDocstring)
2612       .def_property_readonly("type", [](PyValue &self) {
2613         return PyType(self.getParentOperation()->getContext(),
2614                       mlirValueGetType(self.get()));
2615       });
  // Value subclasses. NOTE(review): these presumably register Value as their
  // pybind11 base, so they must stay after the Value binding above. Confirm
  // in their bind() implementations.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings.
  // NOTE(review): registration order can affect the pybind11-generated
  // signature docstrings, so keep this order unless it has been verified not
  // to matter.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  PyGlobalDebugFlag::bind(m);
2633 }
2634