1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include <pybind11/stl.h>
23 
24 namespace py = pybind11;
25 using namespace mlir;
26 using namespace mlir::python;
27 
28 using llvm::SmallVector;
29 using llvm::StringRef;
30 using llvm::Twine;
31 
32 //------------------------------------------------------------------------------
33 // Docstrings (trivial, non-duplicated docstrings are included inline).
34 //------------------------------------------------------------------------------
35 
// Python docstring for parsing a type from its textual assembly form.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Python docstring for building a call-site Location (callee + caller).
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

// Python docstring for building a file/line/column Location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Python docstring for building a named Location. Note: the identifier uses
// "DocString" rather than the "Docstring" spelling of its siblings; it is
// kept as-is because it may be referenced outside this view.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

// Python docstring for parsing a whole module from its assembly text.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Python docstring for generic operation creation. Each entry under "Args:"
// is expected to mirror a keyword argument of the bound function (the
// binding itself is not visible here — confirm names stay in sync).
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
78 
// Python docstring for the operation print() method. Every name listed under
// "Args:" must match the Python keyword argument exactly (kwargs are
// case-sensitive), so users can copy names straight from help().
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";
104 
// Python docstring for get_asm(); shares its keyword arguments with print().
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Python docstring for an operation's __str__ (default printing options).
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared Python docstring for dump() methods.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Python docstring for BlockList.append (bound in PyBlockList::bind below).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Python docstring for a Value's __str__.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
143 
144 //------------------------------------------------------------------------------
145 // Utilities.
146 //------------------------------------------------------------------------------
147 
148 /// Helper for creating an @classmethod.
149 template <class Func, typename... Args>
150 py::object classmethod(Func f, Args... args) {
151   py::object cf = py::cpp_function(f, args...);
152   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
153 }
154 
155 static py::object
156 createCustomDialectWrapper(const std::string &dialectNamespace,
157                            py::object dialectDescriptor) {
158   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
159   if (!dialectClass) {
160     // Use the base class.
161     return py::cast(PyDialect(std::move(dialectDescriptor)));
162   }
163 
164   // Create the custom implementation.
165   return (*dialectClass)(std::move(dialectDescriptor));
166 }
167 
168 static MlirStringRef toMlirStringRef(const std::string &s) {
169   return mlirStringRefCreate(s.data(), s.size());
170 }
171 
172 /// Wrapper for the global LLVM debugging flag.
173 struct PyGlobalDebugFlag {
174   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
175 
176   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
177 
178   static void bind(py::module &m) {
179     // Debug flags.
180     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
181         .def_property_static("flag", &PyGlobalDebugFlag::get,
182                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
183   }
184 };
185 
186 //------------------------------------------------------------------------------
187 // Collections.
188 //------------------------------------------------------------------------------
189 
190 namespace {
191 
192 class PyRegionIterator {
193 public:
194   PyRegionIterator(PyOperationRef operation)
195       : operation(std::move(operation)) {}
196 
197   PyRegionIterator &dunderIter() { return *this; }
198 
199   PyRegion dunderNext() {
200     operation->checkValid();
201     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
202       throw py::stop_iteration();
203     }
204     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
205     return PyRegion(operation, region);
206   }
207 
208   static void bind(py::module &m) {
209     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
210         .def("__iter__", &PyRegionIterator::dunderIter)
211         .def("__next__", &PyRegionIterator::dunderNext);
212   }
213 
214 private:
215   PyOperationRef operation;
216   int nextIndex = 0;
217 };
218 
219 /// Regions of an op are fixed length and indexed numerically so are represented
220 /// with a sequence-like container.
221 class PyRegionList {
222 public:
223   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
224 
225   intptr_t dunderLen() {
226     operation->checkValid();
227     return mlirOperationGetNumRegions(operation->get());
228   }
229 
230   PyRegion dunderGetItem(intptr_t index) {
231     // dunderLen checks validity.
232     if (index < 0 || index >= dunderLen()) {
233       throw SetPyError(PyExc_IndexError,
234                        "attempt to access out of bounds region");
235     }
236     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
237     return PyRegion(operation, region);
238   }
239 
240   static void bind(py::module &m) {
241     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
242         .def("__len__", &PyRegionList::dunderLen)
243         .def("__getitem__", &PyRegionList::dunderGetItem);
244   }
245 
246 private:
247   PyOperationRef operation;
248 };
249 
250 class PyBlockIterator {
251 public:
252   PyBlockIterator(PyOperationRef operation, MlirBlock next)
253       : operation(std::move(operation)), next(next) {}
254 
255   PyBlockIterator &dunderIter() { return *this; }
256 
257   PyBlock dunderNext() {
258     operation->checkValid();
259     if (mlirBlockIsNull(next)) {
260       throw py::stop_iteration();
261     }
262 
263     PyBlock returnBlock(operation, next);
264     next = mlirBlockGetNextInRegion(next);
265     return returnBlock;
266   }
267 
268   static void bind(py::module &m) {
269     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
270         .def("__iter__", &PyBlockIterator::dunderIter)
271         .def("__next__", &PyBlockIterator::dunderNext);
272   }
273 
274 private:
275   PyOperationRef operation;
276   MlirBlock next;
277 };
278 
279 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
280 /// we present them as a more full-featured list-like container but optimize
281 /// it for forward iteration. Blocks are always owned by a region.
282 class PyBlockList {
283 public:
284   PyBlockList(PyOperationRef operation, MlirRegion region)
285       : operation(std::move(operation)), region(region) {}
286 
287   PyBlockIterator dunderIter() {
288     operation->checkValid();
289     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
290   }
291 
292   intptr_t dunderLen() {
293     operation->checkValid();
294     intptr_t count = 0;
295     MlirBlock block = mlirRegionGetFirstBlock(region);
296     while (!mlirBlockIsNull(block)) {
297       count += 1;
298       block = mlirBlockGetNextInRegion(block);
299     }
300     return count;
301   }
302 
303   PyBlock dunderGetItem(intptr_t index) {
304     operation->checkValid();
305     if (index < 0) {
306       throw SetPyError(PyExc_IndexError,
307                        "attempt to access out of bounds block");
308     }
309     MlirBlock block = mlirRegionGetFirstBlock(region);
310     while (!mlirBlockIsNull(block)) {
311       if (index == 0) {
312         return PyBlock(operation, block);
313       }
314       block = mlirBlockGetNextInRegion(block);
315       index -= 1;
316     }
317     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
318   }
319 
320   PyBlock appendBlock(py::args pyArgTypes) {
321     operation->checkValid();
322     llvm::SmallVector<MlirType, 4> argTypes;
323     argTypes.reserve(pyArgTypes.size());
324     for (auto &pyArg : pyArgTypes) {
325       argTypes.push_back(pyArg.cast<PyType &>());
326     }
327 
328     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
329     mlirRegionAppendOwnedBlock(region, block);
330     return PyBlock(operation, block);
331   }
332 
333   static void bind(py::module &m) {
334     py::class_<PyBlockList>(m, "BlockList", py::module_local())
335         .def("__getitem__", &PyBlockList::dunderGetItem)
336         .def("__iter__", &PyBlockList::dunderIter)
337         .def("__len__", &PyBlockList::dunderLen)
338         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
339   }
340 
341 private:
342   PyOperationRef operation;
343   MlirRegion region;
344 };
345 
346 class PyOperationIterator {
347 public:
348   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
349       : parentOperation(std::move(parentOperation)), next(next) {}
350 
351   PyOperationIterator &dunderIter() { return *this; }
352 
353   py::object dunderNext() {
354     parentOperation->checkValid();
355     if (mlirOperationIsNull(next)) {
356       throw py::stop_iteration();
357     }
358 
359     PyOperationRef returnOperation =
360         PyOperation::forOperation(parentOperation->getContext(), next);
361     next = mlirOperationGetNextInBlock(next);
362     return returnOperation->createOpView();
363   }
364 
365   static void bind(py::module &m) {
366     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
367         .def("__iter__", &PyOperationIterator::dunderIter)
368         .def("__next__", &PyOperationIterator::dunderNext);
369   }
370 
371 private:
372   PyOperationRef parentOperation;
373   MlirOperation next;
374 };
375 
376 /// Operations are exposed by the C-API as a forward-only linked list. In
377 /// Python, we present them as a more full-featured list-like container but
378 /// optimize it for forward iteration. Iterable operations are always owned
379 /// by a block.
380 class PyOperationList {
381 public:
382   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
383       : parentOperation(std::move(parentOperation)), block(block) {}
384 
385   PyOperationIterator dunderIter() {
386     parentOperation->checkValid();
387     return PyOperationIterator(parentOperation,
388                                mlirBlockGetFirstOperation(block));
389   }
390 
391   intptr_t dunderLen() {
392     parentOperation->checkValid();
393     intptr_t count = 0;
394     MlirOperation childOp = mlirBlockGetFirstOperation(block);
395     while (!mlirOperationIsNull(childOp)) {
396       count += 1;
397       childOp = mlirOperationGetNextInBlock(childOp);
398     }
399     return count;
400   }
401 
402   py::object dunderGetItem(intptr_t index) {
403     parentOperation->checkValid();
404     if (index < 0) {
405       throw SetPyError(PyExc_IndexError,
406                        "attempt to access out of bounds operation");
407     }
408     MlirOperation childOp = mlirBlockGetFirstOperation(block);
409     while (!mlirOperationIsNull(childOp)) {
410       if (index == 0) {
411         return PyOperation::forOperation(parentOperation->getContext(), childOp)
412             ->createOpView();
413       }
414       childOp = mlirOperationGetNextInBlock(childOp);
415       index -= 1;
416     }
417     throw SetPyError(PyExc_IndexError,
418                      "attempt to access out of bounds operation");
419   }
420 
421   static void bind(py::module &m) {
422     py::class_<PyOperationList>(m, "OperationList", py::module_local())
423         .def("__getitem__", &PyOperationList::dunderGetItem)
424         .def("__iter__", &PyOperationList::dunderIter)
425         .def("__len__", &PyOperationList::dunderLen);
426   }
427 
428 private:
429   PyOperationRef parentOperation;
430   MlirBlock block;
431 };
432 
433 } // namespace
434 
435 //------------------------------------------------------------------------------
436 // PyMlirContext
437 //------------------------------------------------------------------------------
438 
439 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
440   py::gil_scoped_acquire acquire;
441   auto &liveContexts = getLiveContexts();
442   liveContexts[context.ptr] = this;
443 }
444 
445 PyMlirContext::~PyMlirContext() {
446   // Note that the only public way to construct an instance is via the
447   // forContext method, which always puts the associated handle into
448   // liveContexts.
449   py::gil_scoped_acquire acquire;
450   getLiveContexts().erase(context.ptr);
451   mlirContextDestroy(context);
452 }
453 
454 py::object PyMlirContext::getCapsule() {
455   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
456 }
457 
458 py::object PyMlirContext::createFromCapsule(py::object capsule) {
459   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
460   if (mlirContextIsNull(rawContext))
461     throw py::error_already_set();
462   return forContext(rawContext).releaseObject();
463 }
464 
465 PyMlirContext *PyMlirContext::createNewContextForInit() {
466   MlirContext context = mlirContextCreate();
467   mlirRegisterAllDialects(context);
468   return new PyMlirContext(context);
469 }
470 
471 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
472   py::gil_scoped_acquire acquire;
473   auto &liveContexts = getLiveContexts();
474   auto it = liveContexts.find(context.ptr);
475   if (it == liveContexts.end()) {
476     // Create.
477     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
478     py::object pyRef = py::cast(unownedContextWrapper);
479     assert(pyRef && "cast to py::object failed");
480     liveContexts[context.ptr] = unownedContextWrapper;
481     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
482   }
483   // Use existing.
484   py::object pyRef = py::cast(it->second);
485   return PyMlirContextRef(it->second, std::move(pyRef));
486 }
487 
488 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
489   static LiveContextMap liveContexts;
490   return liveContexts;
491 }
492 
493 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
494 
495 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
496 
497 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
498 
499 pybind11::object PyMlirContext::contextEnter() {
500   return PyThreadContextEntry::pushContext(*this);
501 }
502 
503 void PyMlirContext::contextExit(pybind11::object excType,
504                                 pybind11::object excVal,
505                                 pybind11::object excTb) {
506   PyThreadContextEntry::popContext(*this);
507 }
508 
509 PyMlirContext &DefaultingPyMlirContext::resolve() {
510   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
511   if (!context) {
512     throw SetPyError(
513         PyExc_RuntimeError,
514         "An MLIR function requires a Context but none was provided in the call "
515         "or from the surrounding environment. Either pass to the function with "
516         "a 'context=' argument or establish a default using 'with Context():'");
517   }
518   return *context;
519 }
520 
521 //------------------------------------------------------------------------------
522 // PyThreadContextEntry management
523 //------------------------------------------------------------------------------
524 
525 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
526   static thread_local std::vector<PyThreadContextEntry> stack;
527   return stack;
528 }
529 
530 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
531   auto &stack = getStack();
532   if (stack.empty())
533     return nullptr;
534   return &stack.back();
535 }
536 
537 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
538                                 py::object insertionPoint,
539                                 py::object location) {
540   auto &stack = getStack();
541   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
542                      std::move(location));
543   // If the new stack has more than one entry and the context of the new top
544   // entry matches the previous, copy the insertionPoint and location from the
545   // previous entry if missing from the new top entry.
546   if (stack.size() > 1) {
547     auto &prev = *(stack.rbegin() + 1);
548     auto &current = stack.back();
549     if (current.context.is(prev.context)) {
550       // Default non-context objects from the previous entry.
551       if (!current.insertionPoint)
552         current.insertionPoint = prev.insertionPoint;
553       if (!current.location)
554         current.location = prev.location;
555     }
556   }
557 }
558 
559 PyMlirContext *PyThreadContextEntry::getContext() {
560   if (!context)
561     return nullptr;
562   return py::cast<PyMlirContext *>(context);
563 }
564 
565 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
566   if (!insertionPoint)
567     return nullptr;
568   return py::cast<PyInsertionPoint *>(insertionPoint);
569 }
570 
571 PyLocation *PyThreadContextEntry::getLocation() {
572   if (!location)
573     return nullptr;
574   return py::cast<PyLocation *>(location);
575 }
576 
577 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
578   auto *tos = getTopOfStack();
579   return tos ? tos->getContext() : nullptr;
580 }
581 
582 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
583   auto *tos = getTopOfStack();
584   return tos ? tos->getInsertionPoint() : nullptr;
585 }
586 
587 PyLocation *PyThreadContextEntry::getDefaultLocation() {
588   auto *tos = getTopOfStack();
589   return tos ? tos->getLocation() : nullptr;
590 }
591 
592 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
593   py::object contextObj = py::cast(context);
594   push(FrameKind::Context, /*context=*/contextObj,
595        /*insertionPoint=*/py::object(),
596        /*location=*/py::object());
597   return contextObj;
598 }
599 
600 void PyThreadContextEntry::popContext(PyMlirContext &context) {
601   auto &stack = getStack();
602   if (stack.empty())
603     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
604   auto &tos = stack.back();
605   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
606     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
607   stack.pop_back();
608 }
609 
610 py::object
611 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
612   py::object contextObj =
613       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
614   py::object insertionPointObj = py::cast(insertionPoint);
615   push(FrameKind::InsertionPoint,
616        /*context=*/contextObj,
617        /*insertionPoint=*/insertionPointObj,
618        /*location=*/py::object());
619   return insertionPointObj;
620 }
621 
622 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
623   auto &stack = getStack();
624   if (stack.empty())
625     throw SetPyError(PyExc_RuntimeError,
626                      "Unbalanced InsertionPoint enter/exit");
627   auto &tos = stack.back();
628   if (tos.frameKind != FrameKind::InsertionPoint &&
629       tos.getInsertionPoint() != &insertionPoint)
630     throw SetPyError(PyExc_RuntimeError,
631                      "Unbalanced InsertionPoint enter/exit");
632   stack.pop_back();
633 }
634 
635 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
636   py::object contextObj = location.getContext().getObject();
637   py::object locationObj = py::cast(location);
638   push(FrameKind::Location, /*context=*/contextObj,
639        /*insertionPoint=*/py::object(),
640        /*location=*/locationObj);
641   return locationObj;
642 }
643 
644 void PyThreadContextEntry::popLocation(PyLocation &location) {
645   auto &stack = getStack();
646   if (stack.empty())
647     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
648   auto &tos = stack.back();
649   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
650     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
651   stack.pop_back();
652 }
653 
654 //------------------------------------------------------------------------------
655 // PyDialect, PyDialectDescriptor, PyDialects
656 //------------------------------------------------------------------------------
657 
658 MlirDialect PyDialects::getDialectForKey(const std::string &key,
659                                          bool attrError) {
660   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
661                                                     {key.data(), key.size()});
662   if (mlirDialectIsNull(dialect)) {
663     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
664                      Twine("Dialect '") + key + "' not found");
665   }
666   return dialect;
667 }
668 
669 //------------------------------------------------------------------------------
670 // PyLocation
671 //------------------------------------------------------------------------------
672 
673 py::object PyLocation::getCapsule() {
674   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
675 }
676 
677 PyLocation PyLocation::createFromCapsule(py::object capsule) {
678   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
679   if (mlirLocationIsNull(rawLoc))
680     throw py::error_already_set();
681   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
682                     rawLoc);
683 }
684 
685 py::object PyLocation::contextEnter() {
686   return PyThreadContextEntry::pushLocation(*this);
687 }
688 
689 void PyLocation::contextExit(py::object excType, py::object excVal,
690                              py::object excTb) {
691   PyThreadContextEntry::popLocation(*this);
692 }
693 
694 PyLocation &DefaultingPyLocation::resolve() {
695   auto *location = PyThreadContextEntry::getDefaultLocation();
696   if (!location) {
697     throw SetPyError(
698         PyExc_RuntimeError,
699         "An MLIR function requires a Location but none was provided in the "
700         "call or from the surrounding environment. Either pass to the function "
701         "with a 'loc=' argument or establish a default using 'with loc:'");
702   }
703   return *location;
704 }
705 
706 //------------------------------------------------------------------------------
707 // PyModule
708 //------------------------------------------------------------------------------
709 
710 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
711     : BaseContextObject(std::move(contextRef)), module(module) {}
712 
713 PyModule::~PyModule() {
714   py::gil_scoped_acquire acquire;
715   auto &liveModules = getContext()->liveModules;
716   assert(liveModules.count(module.ptr) == 1 &&
717          "destroying module not in live map");
718   liveModules.erase(module.ptr);
719   mlirModuleDestroy(module);
720 }
721 
722 PyModuleRef PyModule::forModule(MlirModule module) {
723   MlirContext context = mlirModuleGetContext(module);
724   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
725 
726   py::gil_scoped_acquire acquire;
727   auto &liveModules = contextRef->liveModules;
728   auto it = liveModules.find(module.ptr);
729   if (it == liveModules.end()) {
730     // Create.
731     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
732     // Note that the default return value policy on cast is automatic_reference,
733     // which does not take ownership (delete will not be called).
734     // Just be explicit.
735     py::object pyRef =
736         py::cast(unownedModule, py::return_value_policy::take_ownership);
737     unownedModule->handle = pyRef;
738     liveModules[module.ptr] =
739         std::make_pair(unownedModule->handle, unownedModule);
740     return PyModuleRef(unownedModule, std::move(pyRef));
741   }
742   // Use existing.
743   PyModule *existing = it->second.second;
744   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
745   return PyModuleRef(existing, std::move(pyRef));
746 }
747 
748 py::object PyModule::createFromCapsule(py::object capsule) {
749   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
750   if (mlirModuleIsNull(rawModule))
751     throw py::error_already_set();
752   return forModule(rawModule).releaseObject();
753 }
754 
755 py::object PyModule::getCapsule() {
756   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
757 }
758 
759 //------------------------------------------------------------------------------
760 // PyOperation
761 //------------------------------------------------------------------------------
762 
763 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
764     : BaseContextObject(std::move(contextRef)), operation(operation) {}
765 
766 PyOperation::~PyOperation() {
767   // If the operation has already been invalidated there is nothing to do.
768   if (!valid)
769     return;
770   auto &liveOperations = getContext()->liveOperations;
771   assert(liveOperations.count(operation.ptr) == 1 &&
772          "destroying operation not in live map");
773   liveOperations.erase(operation.ptr);
774   if (!isAttached()) {
775     mlirOperationDestroy(operation);
776   }
777 }
778 
779 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
780                                            MlirOperation operation,
781                                            py::object parentKeepAlive) {
782   auto &liveOperations = contextRef->liveOperations;
783   // Create.
784   PyOperation *unownedOperation =
785       new PyOperation(std::move(contextRef), operation);
786   // Note that the default return value policy on cast is automatic_reference,
787   // which does not take ownership (delete will not be called).
788   // Just be explicit.
789   py::object pyRef =
790       py::cast(unownedOperation, py::return_value_policy::take_ownership);
791   unownedOperation->handle = pyRef;
792   if (parentKeepAlive) {
793     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
794   }
795   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
796   return PyOperationRef(unownedOperation, std::move(pyRef));
797 }
798 
799 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
800                                          MlirOperation operation,
801                                          py::object parentKeepAlive) {
802   auto &liveOperations = contextRef->liveOperations;
803   auto it = liveOperations.find(operation.ptr);
804   if (it == liveOperations.end()) {
805     // Create.
806     return createInstance(std::move(contextRef), operation,
807                           std::move(parentKeepAlive));
808   }
809   // Use existing.
810   PyOperation *existing = it->second.second;
811   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
812   return PyOperationRef(existing, std::move(pyRef));
813 }
814 
815 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
816                                            MlirOperation operation,
817                                            py::object parentKeepAlive) {
818   auto &liveOperations = contextRef->liveOperations;
819   assert(liveOperations.count(operation.ptr) == 0 &&
820          "cannot create detached operation that already exists");
821   (void)liveOperations;
822 
823   PyOperationRef created = createInstance(std::move(contextRef), operation,
824                                           std::move(parentKeepAlive));
825   created->attached = false;
826   return created;
827 }
828 
829 void PyOperation::checkValid() const {
830   if (!valid) {
831     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
832   }
833 }
834 
835 void PyOperationBase::print(py::object fileObject, bool binary,
836                             llvm::Optional<int64_t> largeElementsLimit,
837                             bool enableDebugInfo, bool prettyDebugInfo,
838                             bool printGenericOpForm, bool useLocalScope,
839                             bool assumeVerified) {
840   PyOperation &operation = getOperation();
841   operation.checkValid();
842   if (fileObject.is_none())
843     fileObject = py::module::import("sys").attr("stdout");
844 
845   if (!assumeVerified && !printGenericOpForm &&
846       !mlirOperationVerify(operation)) {
847     std::string message("// Verification failed, printing generic form\n");
848     if (binary) {
849       fileObject.attr("write")(py::bytes(message));
850     } else {
851       fileObject.attr("write")(py::str(message));
852     }
853     printGenericOpForm = true;
854   }
855 
856   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
857   if (largeElementsLimit)
858     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
859   if (enableDebugInfo)
860     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
861   if (printGenericOpForm)
862     mlirOpPrintingFlagsPrintGenericOpForm(flags);
863 
864   PyFileAccumulator accum(fileObject, binary);
865   py::gil_scoped_release();
866   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
867                               accum.getUserData());
868   mlirOpPrintingFlagsDestroy(flags);
869 }
870 
871 py::object PyOperationBase::getAsm(bool binary,
872                                    llvm::Optional<int64_t> largeElementsLimit,
873                                    bool enableDebugInfo, bool prettyDebugInfo,
874                                    bool printGenericOpForm, bool useLocalScope,
875                                    bool assumeVerified) {
876   py::object fileObject;
877   if (binary) {
878     fileObject = py::module::import("io").attr("BytesIO")();
879   } else {
880     fileObject = py::module::import("io").attr("StringIO")();
881   }
882   print(fileObject, /*binary=*/binary,
883         /*largeElementsLimit=*/largeElementsLimit,
884         /*enableDebugInfo=*/enableDebugInfo,
885         /*prettyDebugInfo=*/prettyDebugInfo,
886         /*printGenericOpForm=*/printGenericOpForm,
887         /*useLocalScope=*/useLocalScope,
888         /*assumeVerified=*/assumeVerified);
889 
890   return fileObject.attr("getvalue")();
891 }
892 
893 void PyOperationBase::moveAfter(PyOperationBase &other) {
894   PyOperation &operation = getOperation();
895   PyOperation &otherOp = other.getOperation();
896   operation.checkValid();
897   otherOp.checkValid();
898   mlirOperationMoveAfter(operation, otherOp);
899   operation.parentKeepAlive = otherOp.parentKeepAlive;
900 }
901 
902 void PyOperationBase::moveBefore(PyOperationBase &other) {
903   PyOperation &operation = getOperation();
904   PyOperation &otherOp = other.getOperation();
905   operation.checkValid();
906   otherOp.checkValid();
907   mlirOperationMoveBefore(operation, otherOp);
908   operation.parentKeepAlive = otherOp.parentKeepAlive;
909 }
910 
911 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
912   checkValid();
913   if (!isAttached())
914     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
915   MlirOperation operation = mlirOperationGetParentOperation(get());
916   if (mlirOperationIsNull(operation))
917     return {};
918   return PyOperation::forOperation(getContext(), operation);
919 }
920 
921 PyBlock PyOperation::getBlock() {
922   checkValid();
923   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
924   MlirBlock block = mlirOperationGetBlock(get());
925   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
926   assert(parentOperation && "Operation has no parent");
927   return PyBlock{std::move(*parentOperation), block};
928 }
929 
930 py::object PyOperation::getCapsule() {
931   checkValid();
932   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
933 }
934 
935 py::object PyOperation::createFromCapsule(py::object capsule) {
936   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
937   if (mlirOperationIsNull(rawOperation))
938     throw py::error_already_set();
939   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
940   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
941       .releaseObject();
942 }
943 
/// Creates a new operation named `name` and returns an OpView for it.
///
/// `results`, `operands`, `attributes` and `successors` are optional. When
/// present, each element is validated before any MLIR state is built: None
/// entries are rejected, and attribute keys must be strings. `regions` is the
/// number of empty regions to create and must be >= 0. The operation is
/// created detached, in the context of `location`.
///
/// `maybeIp` selects where the new operation is inserted:
///   - `False`: not inserted (it stays detached);
///   - `None`:  the thread's current default insertion point, if there is one;
///   - otherwise: cast to a PyInsertionPoint and used directly.
///
/// Raises ValueError (via SetPyError) for an invalid region count or None
/// operands/results/successors, and py::cast_error for invalid attribute
/// keys or values.
py::object PyOperation::create(
    std::string name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, py::object maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Attribute names are stored as owned std::strings here because the
  // MlirNamedAttribute list below references these bytes.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        // (reference_cast_error derives from cast_error, so it must be
        // caught first.)
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Each region is created empty; ownership is handed to the state.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);

  // InsertPoint active?
  // An explicit `False` opts out of insertion entirely; `None` means "use the
  // thread's default insertion point" (which may itself be absent).
  if (!maybeIp.is(py::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
      ip = PyThreadContextEntry::getDefaultInsertionPoint();
    } else {
      ip = py::cast<PyInsertionPoint *>(maybeIp);
    }
    if (ip)
      ip->insert(*created.get());
  }

  return created->createOpView();
}
1076 
1077 py::object PyOperation::createOpView() {
1078   checkValid();
1079   MlirIdentifier ident = mlirOperationGetName(get());
1080   MlirStringRef identStr = mlirIdentifierStr(ident);
1081   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1082       StringRef(identStr.data, identStr.length));
1083   if (opViewClass)
1084     return (*opViewClass)(getRef().getObject());
1085   return py::cast(PyOpView(getRef().getObject()));
1086 }
1087 
1088 void PyOperation::erase() {
1089   checkValid();
1090   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1091   // Python reference to a child operation is live. All children should also
1092   // have their `valid` bit set to false.
1093   auto &liveOperations = getContext()->liveOperations;
1094   if (liveOperations.count(operation.ptr))
1095     liveOperations.erase(operation.ptr);
1096   mlirOperationDestroy(operation);
1097   valid = false;
1098 }
1099 
1100 //------------------------------------------------------------------------------
1101 // PyOpView
1102 //------------------------------------------------------------------------------
1103 
1104 py::object
1105 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1106                        py::list operandList,
1107                        llvm::Optional<py::dict> attributes,
1108                        llvm::Optional<std::vector<PyBlock *>> successors,
1109                        llvm::Optional<int> regions,
1110                        DefaultingPyLocation location, py::object maybeIp) {
1111   PyMlirContextRef context = location->getContext();
1112   // Class level operation construction metadata.
1113   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1114   // Operand and result segment specs are either none, which does no
1115   // variadic unpacking, or a list of ints with segment sizes, where each
1116   // element is either a positive number (typically 1 for a scalar) or -1 to
1117   // indicate that it is derived from the length of the same-indexed operand
1118   // or result (implying that it is a list at that position).
1119   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1120   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1121 
1122   std::vector<uint32_t> operandSegmentLengths;
1123   std::vector<uint32_t> resultSegmentLengths;
1124 
1125   // Validate/determine region count.
1126   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1127   int opMinRegionCount = std::get<0>(opRegionSpec);
1128   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1129   if (!regions) {
1130     regions = opMinRegionCount;
1131   }
1132   if (*regions < opMinRegionCount) {
1133     throw py::value_error(
1134         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1135          llvm::Twine(opMinRegionCount) +
1136          " regions but was built with regions=" + llvm::Twine(*regions))
1137             .str());
1138   }
1139   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1140     throw py::value_error(
1141         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1142          llvm::Twine(opMinRegionCount) +
1143          " regions but was built with regions=" + llvm::Twine(*regions))
1144             .str());
1145   }
1146 
1147   // Unpack results.
1148   std::vector<PyType *> resultTypes;
1149   resultTypes.reserve(resultTypeList.size());
1150   if (resultSegmentSpecObj.is_none()) {
1151     // Non-variadic result unpacking.
1152     for (auto it : llvm::enumerate(resultTypeList)) {
1153       try {
1154         resultTypes.push_back(py::cast<PyType *>(it.value()));
1155         if (!resultTypes.back())
1156           throw py::cast_error();
1157       } catch (py::cast_error &err) {
1158         throw py::value_error((llvm::Twine("Result ") +
1159                                llvm::Twine(it.index()) + " of operation \"" +
1160                                name + "\" must be a Type (" + err.what() + ")")
1161                                   .str());
1162       }
1163     }
1164   } else {
1165     // Sized result unpacking.
1166     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1167     if (resultSegmentSpec.size() != resultTypeList.size()) {
1168       throw py::value_error((llvm::Twine("Operation \"") + name +
1169                              "\" requires " +
1170                              llvm::Twine(resultSegmentSpec.size()) +
1171                              " result segments but was provided " +
1172                              llvm::Twine(resultTypeList.size()))
1173                                 .str());
1174     }
1175     resultSegmentLengths.reserve(resultTypeList.size());
1176     for (auto it :
1177          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1178       int segmentSpec = std::get<1>(it.value());
1179       if (segmentSpec == 1 || segmentSpec == 0) {
1180         // Unpack unary element.
1181         try {
1182           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1183           if (resultType) {
1184             resultTypes.push_back(resultType);
1185             resultSegmentLengths.push_back(1);
1186           } else if (segmentSpec == 0) {
1187             // Allowed to be optional.
1188             resultSegmentLengths.push_back(0);
1189           } else {
1190             throw py::cast_error("was None and result is not optional");
1191           }
1192         } catch (py::cast_error &err) {
1193           throw py::value_error((llvm::Twine("Result ") +
1194                                  llvm::Twine(it.index()) + " of operation \"" +
1195                                  name + "\" must be a Type (" + err.what() +
1196                                  ")")
1197                                     .str());
1198         }
1199       } else if (segmentSpec == -1) {
1200         // Unpack sequence by appending.
1201         try {
1202           if (std::get<0>(it.value()).is_none()) {
1203             // Treat it as an empty list.
1204             resultSegmentLengths.push_back(0);
1205           } else {
1206             // Unpack the list.
1207             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1208             for (py::object segmentItem : segment) {
1209               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1210               if (!resultTypes.back()) {
1211                 throw py::cast_error("contained a None item");
1212               }
1213             }
1214             resultSegmentLengths.push_back(segment.size());
1215           }
1216         } catch (std::exception &err) {
1217           // NOTE: Sloppy to be using a catch-all here, but there are at least
1218           // three different unrelated exceptions that can be thrown in the
1219           // above "casts". Just keep the scope above small and catch them all.
1220           throw py::value_error((llvm::Twine("Result ") +
1221                                  llvm::Twine(it.index()) + " of operation \"" +
1222                                  name + "\" must be a Sequence of Types (" +
1223                                  err.what() + ")")
1224                                     .str());
1225         }
1226       } else {
1227         throw py::value_error("Unexpected segment spec");
1228       }
1229     }
1230   }
1231 
1232   // Unpack operands.
1233   std::vector<PyValue *> operands;
1234   operands.reserve(operands.size());
1235   if (operandSegmentSpecObj.is_none()) {
1236     // Non-sized operand unpacking.
1237     for (auto it : llvm::enumerate(operandList)) {
1238       try {
1239         operands.push_back(py::cast<PyValue *>(it.value()));
1240         if (!operands.back())
1241           throw py::cast_error();
1242       } catch (py::cast_error &err) {
1243         throw py::value_error((llvm::Twine("Operand ") +
1244                                llvm::Twine(it.index()) + " of operation \"" +
1245                                name + "\" must be a Value (" + err.what() + ")")
1246                                   .str());
1247       }
1248     }
1249   } else {
1250     // Sized operand unpacking.
1251     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1252     if (operandSegmentSpec.size() != operandList.size()) {
1253       throw py::value_error((llvm::Twine("Operation \"") + name +
1254                              "\" requires " +
1255                              llvm::Twine(operandSegmentSpec.size()) +
1256                              "operand segments but was provided " +
1257                              llvm::Twine(operandList.size()))
1258                                 .str());
1259     }
1260     operandSegmentLengths.reserve(operandList.size());
1261     for (auto it :
1262          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1263       int segmentSpec = std::get<1>(it.value());
1264       if (segmentSpec == 1 || segmentSpec == 0) {
1265         // Unpack unary element.
1266         try {
1267           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1268           if (operandValue) {
1269             operands.push_back(operandValue);
1270             operandSegmentLengths.push_back(1);
1271           } else if (segmentSpec == 0) {
1272             // Allowed to be optional.
1273             operandSegmentLengths.push_back(0);
1274           } else {
1275             throw py::cast_error("was None and operand is not optional");
1276           }
1277         } catch (py::cast_error &err) {
1278           throw py::value_error((llvm::Twine("Operand ") +
1279                                  llvm::Twine(it.index()) + " of operation \"" +
1280                                  name + "\" must be a Value (" + err.what() +
1281                                  ")")
1282                                     .str());
1283         }
1284       } else if (segmentSpec == -1) {
1285         // Unpack sequence by appending.
1286         try {
1287           if (std::get<0>(it.value()).is_none()) {
1288             // Treat it as an empty list.
1289             operandSegmentLengths.push_back(0);
1290           } else {
1291             // Unpack the list.
1292             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1293             for (py::object segmentItem : segment) {
1294               operands.push_back(py::cast<PyValue *>(segmentItem));
1295               if (!operands.back()) {
1296                 throw py::cast_error("contained a None item");
1297               }
1298             }
1299             operandSegmentLengths.push_back(segment.size());
1300           }
1301         } catch (std::exception &err) {
1302           // NOTE: Sloppy to be using a catch-all here, but there are at least
1303           // three different unrelated exceptions that can be thrown in the
1304           // above "casts". Just keep the scope above small and catch them all.
1305           throw py::value_error((llvm::Twine("Operand ") +
1306                                  llvm::Twine(it.index()) + " of operation \"" +
1307                                  name + "\" must be a Sequence of Values (" +
1308                                  err.what() + ")")
1309                                     .str());
1310         }
1311       } else {
1312         throw py::value_error("Unexpected segment spec");
1313       }
1314     }
1315   }
1316 
1317   // Merge operand/result segment lengths into attributes if needed.
1318   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1319     // Dup.
1320     if (attributes) {
1321       attributes = py::dict(*attributes);
1322     } else {
1323       attributes = py::dict();
1324     }
1325     if (attributes->contains("result_segment_sizes") ||
1326         attributes->contains("operand_segment_sizes")) {
1327       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1328                             "'operand_segment_sizes' attribute is unsupported. "
1329                             "Use Operation.create for such low-level access.");
1330     }
1331 
1332     // Add result_segment_sizes attribute.
1333     if (!resultSegmentLengths.empty()) {
1334       int64_t size = resultSegmentLengths.size();
1335       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1336           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1337           resultSegmentLengths.size(), resultSegmentLengths.data());
1338       (*attributes)["result_segment_sizes"] =
1339           PyAttribute(context, segmentLengthAttr);
1340     }
1341 
1342     // Add operand_segment_sizes attribute.
1343     if (!operandSegmentLengths.empty()) {
1344       int64_t size = operandSegmentLengths.size();
1345       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1346           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1347           operandSegmentLengths.size(), operandSegmentLengths.data());
1348       (*attributes)["operand_segment_sizes"] =
1349           PyAttribute(context, segmentLengthAttr);
1350     }
1351   }
1352 
1353   // Delegate to create.
1354   return PyOperation::create(std::move(name),
1355                              /*results=*/std::move(resultTypes),
1356                              /*operands=*/std::move(operands),
1357                              /*attributes=*/std::move(attributes),
1358                              /*successors=*/std::move(successors),
1359                              /*regions=*/*regions, location, maybeIp);
1360 }
1361 
1362 PyOpView::PyOpView(py::object operationObject)
1363     // Casting through the PyOperationBase base-class and then back to the
1364     // Operation lets us accept any PyOperationBase subclass.
1365     : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1366       operationObject(operation.getRef().getObject()) {}
1367 
1368 py::object PyOpView::createRawSubclass(py::object userClass) {
1369   // This is... a little gross. The typical pattern is to have a pure python
1370   // class that extends OpView like:
1371   //   class AddFOp(_cext.ir.OpView):
1372   //     def __init__(self, loc, lhs, rhs):
1373   //       operation = loc.context.create_operation(
1374   //           "addf", lhs, rhs, results=[lhs.type])
1375   //       super().__init__(operation)
1376   //
1377   // I.e. The goal of the user facing type is to provide a nice constructor
1378   // that has complete freedom for the op under construction. This is at odds
1379   // with our other desire to sometimes create this object by just passing an
1380   // operation (to initialize the base class). We could do *arg and **kwargs
1381   // munging to try to make it work, but instead, we synthesize a new class
1382   // on the fly which extends this user class (AddFOp in this example) and
1383   // *give it* the base class's __init__ method, thus bypassing the
1384   // intermediate subclass's __init__ method entirely. While slightly,
1385   // underhanded, this is safe/legal because the type hierarchy has not changed
1386   // (we just added a new leaf) and we aren't mucking around with __new__.
1387   // Typically, this new class will be stored on the original as "_Raw" and will
1388   // be used for casts and other things that need a variant of the class that
1389   // is initialized purely from an operation.
1390   py::object parentMetaclass =
1391       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1392   py::dict attributes;
1393   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1394   // now.
1395   //   auto opViewType = py::type::of<PyOpView>();
1396   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1397   attributes["__init__"] = opViewType.attr("__init__");
1398   py::str origName = userClass.attr("__name__");
1399   py::str newName = py::str("_") + origName;
1400   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1401 }
1402 
1403 //------------------------------------------------------------------------------
1404 // PyInsertionPoint.
1405 //------------------------------------------------------------------------------
1406 
1407 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1408 
1409 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1410     : refOperation(beforeOperationBase.getOperation().getRef()),
1411       block((*refOperation)->getBlock()) {}
1412 
1413 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1414   PyOperation &operation = operationBase.getOperation();
1415   if (operation.isAttached())
1416     throw SetPyError(PyExc_ValueError,
1417                      "Attempt to insert operation that is already attached");
1418   block.getParentOperation()->checkValid();
1419   MlirOperation beforeOp = {nullptr};
1420   if (refOperation) {
1421     // Insert before operation.
1422     (*refOperation)->checkValid();
1423     beforeOp = (*refOperation)->get();
1424   } else {
1425     // Insert at end (before null) is only valid if the block does not
1426     // already end in a known terminator (violating this will cause assertion
1427     // failures later).
1428     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1429       throw py::index_error("Cannot insert operation at the end of a block "
1430                             "that already has a terminator. Did you mean to "
1431                             "use 'InsertionPoint.at_block_terminator(block)' "
1432                             "versus 'InsertionPoint(block)'?");
1433     }
1434   }
1435   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1436   operation.setAttached();
1437 }
1438 
1439 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1440   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1441   if (mlirOperationIsNull(firstOp)) {
1442     // Just insert at end.
1443     return PyInsertionPoint(block);
1444   }
1445 
1446   // Insert before first op.
1447   PyOperationRef firstOpRef = PyOperation::forOperation(
1448       block.getParentOperation()->getContext(), firstOp);
1449   return PyInsertionPoint{block, std::move(firstOpRef)};
1450 }
1451 
1452 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1453   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1454   if (mlirOperationIsNull(terminator))
1455     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1456   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1457       block.getParentOperation()->getContext(), terminator);
1458   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1459 }
1460 
1461 py::object PyInsertionPoint::contextEnter() {
1462   return PyThreadContextEntry::pushInsertionPoint(*this);
1463 }
1464 
1465 void PyInsertionPoint::contextExit(pybind11::object excType,
1466                                    pybind11::object excVal,
1467                                    pybind11::object excTb) {
1468   PyThreadContextEntry::popInsertionPoint(*this);
1469 }
1470 
1471 //------------------------------------------------------------------------------
1472 // PyAttribute.
1473 //------------------------------------------------------------------------------
1474 
1475 bool PyAttribute::operator==(const PyAttribute &other) {
1476   return mlirAttributeEqual(attr, other.attr);
1477 }
1478 
1479 py::object PyAttribute::getCapsule() {
1480   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1481 }
1482 
1483 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1484   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1485   if (mlirAttributeIsNull(rawAttr))
1486     throw py::error_already_set();
1487   return PyAttribute(
1488       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1489 }
1490 
1491 //------------------------------------------------------------------------------
1492 // PyNamedAttribute.
1493 //------------------------------------------------------------------------------
1494 
1495 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1496     : ownedName(new std::string(std::move(ownedName))) {
1497   namedAttr = mlirNamedAttributeGet(
1498       mlirIdentifierGet(mlirAttributeGetContext(attr),
1499                         toMlirStringRef(*this->ownedName)),
1500       attr);
1501 }
1502 
1503 //------------------------------------------------------------------------------
1504 // PyType.
1505 //------------------------------------------------------------------------------
1506 
1507 bool PyType::operator==(const PyType &other) {
1508   return mlirTypeEqual(type, other.type);
1509 }
1510 
1511 py::object PyType::getCapsule() {
1512   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1513 }
1514 
1515 PyType PyType::createFromCapsule(py::object capsule) {
1516   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1517   if (mlirTypeIsNull(rawType))
1518     throw py::error_already_set();
1519   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1520                 rawType);
1521 }
1522 
1523 //------------------------------------------------------------------------------
// PyValue and subclasses.
1525 //------------------------------------------------------------------------------
1526 
1527 pybind11::object PyValue::getCapsule() {
1528   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1529 }
1530 
1531 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1532   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1533   if (mlirValueIsNull(value))
1534     throw py::error_already_set();
1535   MlirOperation owner;
1536   if (mlirValueIsAOpResult(value))
1537     owner = mlirOpResultGetOwner(value);
1538   if (mlirValueIsABlockArgument(value))
1539     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1540   if (mlirOperationIsNull(owner))
1541     throw py::error_already_set();
1542   MlirContext ctx = mlirOperationGetContext(owner);
1543   PyOperationRef ownerRef =
1544       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1545   return PyValue(ownerRef, value);
1546 }
1547 
1548 //------------------------------------------------------------------------------
1549 // PySymbolTable.
1550 //------------------------------------------------------------------------------
1551 
1552 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1553     : operation(operation.getOperation().getRef()) {
1554   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1555   if (mlirSymbolTableIsNull(symbolTable)) {
1556     throw py::cast_error("Operation is not a Symbol Table.");
1557   }
1558 }
1559 
1560 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1561   operation->checkValid();
1562   MlirOperation symbol = mlirSymbolTableLookup(
1563       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1564   if (mlirOperationIsNull(symbol))
1565     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1566 
1567   return PyOperation::forOperation(operation->getContext(), symbol,
1568                                    operation.getObject())
1569       ->createOpView();
1570 }
1571 
1572 void PySymbolTable::erase(PyOperationBase &symbol) {
1573   operation->checkValid();
1574   symbol.getOperation().checkValid();
1575   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1576   // The operation is also erased, so we must invalidate it. There may be Python
1577   // references to this operation so we don't want to delete it from the list of
1578   // live operations here.
1579   symbol.getOperation().valid = false;
1580 }
1581 
1582 void PySymbolTable::dunderDel(const std::string &name) {
1583   py::object operation = dunderGetItem(name);
1584   erase(py::cast<PyOperationBase &>(operation));
1585 }
1586 
1587 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1588   operation->checkValid();
1589   symbol.getOperation().checkValid();
1590   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1591       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1592   if (mlirAttributeIsNull(symbolAttr))
1593     throw py::value_error("Expected operation to have a symbol name.");
1594   return PyAttribute(
1595       symbol.getOperation().getContext(),
1596       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1597 }
1598 
1599 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1600   // Op must already be a symbol.
1601   PyOperation &operation = symbol.getOperation();
1602   operation.checkValid();
1603   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1604   MlirAttribute existingNameAttr =
1605       mlirOperationGetAttributeByName(operation.get(), attrName);
1606   if (mlirAttributeIsNull(existingNameAttr))
1607     throw py::value_error("Expected operation to have a symbol name.");
1608   return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1609 }
1610 
1611 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1612                                   const std::string &name) {
1613   // Op must already be a symbol.
1614   PyOperation &operation = symbol.getOperation();
1615   operation.checkValid();
1616   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1617   MlirAttribute existingNameAttr =
1618       mlirOperationGetAttributeByName(operation.get(), attrName);
1619   if (mlirAttributeIsNull(existingNameAttr))
1620     throw py::value_error("Expected operation to have a symbol name.");
1621   MlirAttribute newNameAttr =
1622       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1623   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1624 }
1625 
1626 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1627   PyOperation &operation = symbol.getOperation();
1628   operation.checkValid();
1629   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1630   MlirAttribute existingVisAttr =
1631       mlirOperationGetAttributeByName(operation.get(), attrName);
1632   if (mlirAttributeIsNull(existingVisAttr))
1633     throw py::value_error("Expected operation to have a symbol visibility.");
1634   return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1635 }
1636 
1637 void PySymbolTable::setVisibility(PyOperationBase &symbol,
1638                                   const std::string &visibility) {
1639   if (visibility != "public" && visibility != "private" &&
1640       visibility != "nested")
1641     throw py::value_error(
1642         "Expected visibility to be 'public', 'private' or 'nested'");
1643   PyOperation &operation = symbol.getOperation();
1644   operation.checkValid();
1645   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1646   MlirAttribute existingVisAttr =
1647       mlirOperationGetAttributeByName(operation.get(), attrName);
1648   if (mlirAttributeIsNull(existingVisAttr))
1649     throw py::value_error("Expected operation to have a symbol visibility.");
1650   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1651                                                toMlirStringRef(visibility));
1652   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1653 }
1654 
1655 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1656                                          const std::string &newSymbol,
1657                                          PyOperationBase &from) {
1658   PyOperation &fromOperation = from.getOperation();
1659   fromOperation.checkValid();
1660   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1661           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1662           from.getOperation())))
1663 
1664     throw py::value_error("Symbol rename failed");
1665 }
1666 
1667 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1668                                      bool allSymUsesVisible,
1669                                      py::object callback) {
1670   PyOperation &fromOperation = from.getOperation();
1671   fromOperation.checkValid();
1672   struct UserData {
1673     PyMlirContextRef context;
1674     py::object callback;
1675     bool gotException;
1676     std::string exceptionWhat;
1677     py::object exceptionType;
1678   };
1679   UserData userData{
1680       fromOperation.getContext(), std::move(callback), false, {}, {}};
1681   mlirSymbolTableWalkSymbolTables(
1682       fromOperation.get(), allSymUsesVisible,
1683       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1684         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1685         auto pyFoundOp =
1686             PyOperation::forOperation(calleeUserData->context, foundOp);
1687         if (calleeUserData->gotException)
1688           return;
1689         try {
1690           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1691         } catch (py::error_already_set &e) {
1692           calleeUserData->gotException = true;
1693           calleeUserData->exceptionWhat = e.what();
1694           calleeUserData->exceptionType = e.type();
1695         }
1696       },
1697       static_cast<void *>(&userData));
1698   if (userData.gotException) {
1699     std::string message("Exception raised in callback: ");
1700     message.append(userData.exceptionWhat);
1701     throw std::runtime_error(std::move(message));
1702   }
1703 }
1704 
1705 namespace {
1706 /// CRTP base class for Python MLIR values that subclass Value and should be
1707 /// castable from it. The value hierarchy is one level deep and is not supposed
1708 /// to accommodate other levels unless core MLIR changes.
1709 template <typename DerivedTy>
1710 class PyConcreteValue : public PyValue {
1711 public:
1712   // Derived classes must define statics for:
1713   //   IsAFunctionTy isaFunction
1714   //   const char *pyClassName
1715   // and redefine bindDerived.
1716   using ClassTy = py::class_<DerivedTy, PyValue>;
1717   using IsAFunctionTy = bool (*)(MlirValue);
1718 
1719   PyConcreteValue() = default;
1720   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1721       : PyValue(operationRef, value) {}
1722   PyConcreteValue(PyValue &orig)
1723       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1724 
1725   /// Attempts to cast the original value to the derived type and throws on
1726   /// type mismatches.
1727   static MlirValue castFrom(PyValue &orig) {
1728     if (!DerivedTy::isaFunction(orig.get())) {
1729       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1730       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1731                                              DerivedTy::pyClassName +
1732                                              " (from " + origRepr + ")");
1733     }
1734     return orig.get();
1735   }
1736 
1737   /// Binds the Python module objects to functions of this class.
1738   static void bind(py::module &m) {
1739     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1740     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1741     cls.def_static(
1742         "isinstance",
1743         [](PyValue &otherValue) -> bool {
1744           return DerivedTy::isaFunction(otherValue);
1745         },
1746         py::arg("other_value"));
1747     DerivedTy::bindDerived(cls);
1748   }
1749 
1750   /// Implemented by derived classes to add methods to the Python subclass.
1751   static void bindDerived(ClassTy &m) {}
1752 };
1753 
1754 /// Python wrapper for MlirBlockArgument.
1755 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1756 public:
1757   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1758   static constexpr const char *pyClassName = "BlockArgument";
1759   using PyConcreteValue::PyConcreteValue;
1760 
1761   static void bindDerived(ClassTy &c) {
1762     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1763       return PyBlock(self.getParentOperation(),
1764                      mlirBlockArgumentGetOwner(self.get()));
1765     });
1766     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1767       return mlirBlockArgumentGetArgNumber(self.get());
1768     });
1769     c.def(
1770         "set_type",
1771         [](PyBlockArgument &self, PyType type) {
1772           return mlirBlockArgumentSetType(self.get(), type);
1773         },
1774         py::arg("type"));
1775   }
1776 };
1777 
1778 /// Python wrapper for MlirOpResult.
1779 class PyOpResult : public PyConcreteValue<PyOpResult> {
1780 public:
1781   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1782   static constexpr const char *pyClassName = "OpResult";
1783   using PyConcreteValue::PyConcreteValue;
1784 
1785   static void bindDerived(ClassTy &c) {
1786     c.def_property_readonly("owner", [](PyOpResult &self) {
1787       assert(
1788           mlirOperationEqual(self.getParentOperation()->get(),
1789                              mlirOpResultGetOwner(self.get())) &&
1790           "expected the owner of the value in Python to match that in the IR");
1791       return self.getParentOperation().getObject();
1792     });
1793     c.def_property_readonly("result_number", [](PyOpResult &self) {
1794       return mlirOpResultGetResultNumber(self.get());
1795     });
1796   }
1797 };
1798 
1799 /// Returns the list of types of the values held by container.
1800 template <typename Container>
1801 static std::vector<PyType> getValueTypes(Container &container,
1802                                          PyMlirContextRef &context) {
1803   std::vector<PyType> result;
1804   result.reserve(container.getNumElements());
1805   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1806     result.push_back(
1807         PyType(context, mlirValueGetType(container.getElement(i).get())));
1808   }
1809   return result;
1810 }
1811 
1812 /// A list of block arguments. Internally, these are stored as consecutive
1813 /// elements, random access is cheap. The argument list is associated with the
1814 /// operation that contains the block (detached blocks are not allowed in
1815 /// Python bindings) and extends its lifetime.
1816 class PyBlockArgumentList
1817     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1818 public:
1819   static constexpr const char *pyClassName = "BlockArgumentList";
1820 
1821   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1822                       intptr_t startIndex = 0, intptr_t length = -1,
1823                       intptr_t step = 1)
1824       : Sliceable(startIndex,
1825                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1826                   step),
1827         operation(std::move(operation)), block(block) {}
1828 
1829   /// Returns the number of arguments in the list.
1830   intptr_t getNumElements() {
1831     operation->checkValid();
1832     return mlirBlockGetNumArguments(block);
1833   }
1834 
1835   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1836   PyBlockArgument getElement(intptr_t pos) {
1837     MlirValue argument = mlirBlockGetArgument(block, pos);
1838     return PyBlockArgument(operation, argument);
1839   }
1840 
1841   /// Returns a sublist of this list.
1842   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1843                             intptr_t step) {
1844     return PyBlockArgumentList(operation, block, startIndex, length, step);
1845   }
1846 
1847   static void bindDerived(ClassTy &c) {
1848     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1849       return getValueTypes(self, self.operation->getContext());
1850     });
1851   }
1852 
1853 private:
1854   PyOperationRef operation;
1855   MlirBlock block;
1856 };
1857 
1858 /// A list of operation operands. Internally, these are stored as consecutive
1859 /// elements, random access is cheap. The result list is associated with the
1860 /// operation whose results these are, and extends the lifetime of this
1861 /// operation.
1862 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1863 public:
1864   static constexpr const char *pyClassName = "OpOperandList";
1865 
1866   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1867                   intptr_t length = -1, intptr_t step = 1)
1868       : Sliceable(startIndex,
1869                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1870                                : length,
1871                   step),
1872         operation(operation) {}
1873 
1874   intptr_t getNumElements() {
1875     operation->checkValid();
1876     return mlirOperationGetNumOperands(operation->get());
1877   }
1878 
1879   PyValue getElement(intptr_t pos) {
1880     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1881     MlirOperation owner;
1882     if (mlirValueIsAOpResult(operand))
1883       owner = mlirOpResultGetOwner(operand);
1884     else if (mlirValueIsABlockArgument(operand))
1885       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1886     else
1887       assert(false && "Value must be an block arg or op result.");
1888     PyOperationRef pyOwner =
1889         PyOperation::forOperation(operation->getContext(), owner);
1890     return PyValue(pyOwner, operand);
1891   }
1892 
1893   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1894     return PyOpOperandList(operation, startIndex, length, step);
1895   }
1896 
1897   void dunderSetItem(intptr_t index, PyValue value) {
1898     index = wrapIndex(index);
1899     mlirOperationSetOperand(operation->get(), index, value.get());
1900   }
1901 
1902   static void bindDerived(ClassTy &c) {
1903     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1904   }
1905 
1906 private:
1907   PyOperationRef operation;
1908 };
1909 
1910 /// A list of operation results. Internally, these are stored as consecutive
1911 /// elements, random access is cheap. The result list is associated with the
1912 /// operation whose results these are, and extends the lifetime of this
1913 /// operation.
1914 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1915 public:
1916   static constexpr const char *pyClassName = "OpResultList";
1917 
1918   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1919                  intptr_t length = -1, intptr_t step = 1)
1920       : Sliceable(startIndex,
1921                   length == -1 ? mlirOperationGetNumResults(operation->get())
1922                                : length,
1923                   step),
1924         operation(operation) {}
1925 
1926   intptr_t getNumElements() {
1927     operation->checkValid();
1928     return mlirOperationGetNumResults(operation->get());
1929   }
1930 
1931   PyOpResult getElement(intptr_t index) {
1932     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1933     return PyOpResult(value);
1934   }
1935 
1936   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1937     return PyOpResultList(operation, startIndex, length, step);
1938   }
1939 
1940   static void bindDerived(ClassTy &c) {
1941     c.def_property_readonly("types", [](PyOpResultList &self) {
1942       return getValueTypes(self, self.operation->getContext());
1943     });
1944   }
1945 
1946 private:
1947   PyOperationRef operation;
1948 };
1949 
1950 /// A list of operation attributes. Can be indexed by name, producing
1951 /// attributes, or by index, producing named attributes.
1952 class PyOpAttributeMap {
1953 public:
1954   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1955 
1956   PyAttribute dunderGetItemNamed(const std::string &name) {
1957     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1958                                                          toMlirStringRef(name));
1959     if (mlirAttributeIsNull(attr)) {
1960       throw SetPyError(PyExc_KeyError,
1961                        "attempt to access a non-existent attribute");
1962     }
1963     return PyAttribute(operation->getContext(), attr);
1964   }
1965 
1966   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1967     if (index < 0 || index >= dunderLen()) {
1968       throw SetPyError(PyExc_IndexError,
1969                        "attempt to access out of bounds attribute");
1970     }
1971     MlirNamedAttribute namedAttr =
1972         mlirOperationGetAttribute(operation->get(), index);
1973     return PyNamedAttribute(
1974         namedAttr.attribute,
1975         std::string(mlirIdentifierStr(namedAttr.name).data,
1976                     mlirIdentifierStr(namedAttr.name).length));
1977   }
1978 
1979   void dunderSetItem(const std::string &name, PyAttribute attr) {
1980     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1981                                     attr);
1982   }
1983 
1984   void dunderDelItem(const std::string &name) {
1985     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1986                                                      toMlirStringRef(name));
1987     if (!removed)
1988       throw SetPyError(PyExc_KeyError,
1989                        "attempt to delete a non-existent attribute");
1990   }
1991 
1992   intptr_t dunderLen() {
1993     return mlirOperationGetNumAttributes(operation->get());
1994   }
1995 
1996   bool dunderContains(const std::string &name) {
1997     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1998         operation->get(), toMlirStringRef(name)));
1999   }
2000 
2001   static void bind(py::module &m) {
2002     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2003         .def("__contains__", &PyOpAttributeMap::dunderContains)
2004         .def("__len__", &PyOpAttributeMap::dunderLen)
2005         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2006         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2007         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2008         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2009   }
2010 
2011 private:
2012   PyOperationRef operation;
2013 };
2014 
2015 } // end namespace
2016 
2017 //------------------------------------------------------------------------------
2018 // Populates the core exports of the 'ir' submodule.
2019 //------------------------------------------------------------------------------
2020 
2021 void mlir::python::populateIRCore(py::module &m) {
2022   //----------------------------------------------------------------------------
2023   // Mapping of MlirContext.
2024   //----------------------------------------------------------------------------
2025   py::class_<PyMlirContext>(m, "Context", py::module_local())
2026       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2027       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2028       .def("_get_context_again",
2029            [](PyMlirContext &self) {
2030              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2031              return ref.releaseObject();
2032            })
2033       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2034       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2035       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2036                              &PyMlirContext::getCapsule)
2037       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2038       .def("__enter__", &PyMlirContext::contextEnter)
2039       .def("__exit__", &PyMlirContext::contextExit)
2040       .def_property_readonly_static(
2041           "current",
2042           [](py::object & /*class*/) {
2043             auto *context = PyThreadContextEntry::getDefaultContext();
2044             if (!context)
2045               throw SetPyError(PyExc_ValueError, "No current Context");
2046             return context;
2047           },
2048           "Gets the Context bound to the current thread or raises ValueError")
2049       .def_property_readonly(
2050           "dialects",
2051           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2052           "Gets a container for accessing dialects by name")
2053       .def_property_readonly(
2054           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2055           "Alias for 'dialect'")
2056       .def(
2057           "get_dialect_descriptor",
2058           [=](PyMlirContext &self, std::string &name) {
2059             MlirDialect dialect = mlirContextGetOrLoadDialect(
2060                 self.get(), {name.data(), name.size()});
2061             if (mlirDialectIsNull(dialect)) {
2062               throw SetPyError(PyExc_ValueError,
2063                                Twine("Dialect '") + name + "' not found");
2064             }
2065             return PyDialectDescriptor(self.getRef(), dialect);
2066           },
2067           py::arg("dialect_name"),
2068           "Gets or loads a dialect by name, returning its descriptor object")
2069       .def_property(
2070           "allow_unregistered_dialects",
2071           [](PyMlirContext &self) -> bool {
2072             return mlirContextGetAllowUnregisteredDialects(self.get());
2073           },
2074           [](PyMlirContext &self, bool value) {
2075             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2076           })
2077       .def(
2078           "enable_multithreading",
2079           [](PyMlirContext &self, bool enable) {
2080             mlirContextEnableMultithreading(self.get(), enable);
2081           },
2082           py::arg("enable"))
2083       .def(
2084           "is_registered_operation",
2085           [](PyMlirContext &self, std::string &name) {
2086             return mlirContextIsRegisteredOperation(
2087                 self.get(), MlirStringRef{name.data(), name.size()});
2088           },
2089           py::arg("operation_name"));
2090 
2091   //----------------------------------------------------------------------------
2092   // Mapping of PyDialectDescriptor
2093   //----------------------------------------------------------------------------
2094   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2095       .def_property_readonly("namespace",
2096                              [](PyDialectDescriptor &self) {
2097                                MlirStringRef ns =
2098                                    mlirDialectGetNamespace(self.get());
2099                                return py::str(ns.data, ns.length);
2100                              })
2101       .def("__repr__", [](PyDialectDescriptor &self) {
2102         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2103         std::string repr("<DialectDescriptor ");
2104         repr.append(ns.data, ns.length);
2105         repr.append(">");
2106         return repr;
2107       });
2108 
2109   //----------------------------------------------------------------------------
2110   // Mapping of PyDialects
2111   //----------------------------------------------------------------------------
2112   py::class_<PyDialects>(m, "Dialects", py::module_local())
2113       .def("__getitem__",
2114            [=](PyDialects &self, std::string keyName) {
2115              MlirDialect dialect =
2116                  self.getDialectForKey(keyName, /*attrError=*/false);
2117              py::object descriptor =
2118                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2119              return createCustomDialectWrapper(keyName, std::move(descriptor));
2120            })
2121       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2122         MlirDialect dialect =
2123             self.getDialectForKey(attrName, /*attrError=*/true);
2124         py::object descriptor =
2125             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2126         return createCustomDialectWrapper(attrName, std::move(descriptor));
2127       });
2128 
2129   //----------------------------------------------------------------------------
2130   // Mapping of PyDialect
2131   //----------------------------------------------------------------------------
2132   py::class_<PyDialect>(m, "Dialect", py::module_local())
2133       .def(py::init<py::object>(), py::arg("descriptor"))
2134       .def_property_readonly(
2135           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2136       .def("__repr__", [](py::object self) {
2137         auto clazz = self.attr("__class__");
2138         return py::str("<Dialect ") +
2139                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2140                clazz.attr("__module__") + py::str(".") +
2141                clazz.attr("__name__") + py::str(")>");
2142       });
2143 
2144   //----------------------------------------------------------------------------
2145   // Mapping of Location
2146   //----------------------------------------------------------------------------
2147   py::class_<PyLocation>(m, "Location", py::module_local())
2148       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2149       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2150       .def("__enter__", &PyLocation::contextEnter)
2151       .def("__exit__", &PyLocation::contextExit)
2152       .def("__eq__",
2153            [](PyLocation &self, PyLocation &other) -> bool {
2154              return mlirLocationEqual(self, other);
2155            })
2156       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2157       .def_property_readonly_static(
2158           "current",
2159           [](py::object & /*class*/) {
2160             auto *loc = PyThreadContextEntry::getDefaultLocation();
2161             if (!loc)
2162               throw SetPyError(PyExc_ValueError, "No current Location");
2163             return loc;
2164           },
2165           "Gets the Location bound to the current thread or raises ValueError")
2166       .def_static(
2167           "unknown",
2168           [](DefaultingPyMlirContext context) {
2169             return PyLocation(context->getRef(),
2170                               mlirLocationUnknownGet(context->get()));
2171           },
2172           py::arg("context") = py::none(),
2173           "Gets a Location representing an unknown location")
2174       .def_static(
2175           "callsite",
2176           [](PyLocation callee, const std::vector<PyLocation> &frames,
2177              DefaultingPyMlirContext context) {
2178             if (frames.empty())
2179               throw py::value_error("No caller frames provided");
2180             MlirLocation caller = frames.back().get();
2181             for (const PyLocation &frame :
2182                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2183               caller = mlirLocationCallSiteGet(frame.get(), caller);
2184             return PyLocation(context->getRef(),
2185                               mlirLocationCallSiteGet(callee.get(), caller));
2186           },
2187           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2188           kContextGetCallSiteLocationDocstring)
2189       .def_static(
2190           "file",
2191           [](std::string filename, int line, int col,
2192              DefaultingPyMlirContext context) {
2193             return PyLocation(
2194                 context->getRef(),
2195                 mlirLocationFileLineColGet(
2196                     context->get(), toMlirStringRef(filename), line, col));
2197           },
2198           py::arg("filename"), py::arg("line"), py::arg("col"),
2199           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2200       .def_static(
2201           "name",
2202           [](std::string name, llvm::Optional<PyLocation> childLoc,
2203              DefaultingPyMlirContext context) {
2204             return PyLocation(
2205                 context->getRef(),
2206                 mlirLocationNameGet(
2207                     context->get(), toMlirStringRef(name),
2208                     childLoc ? childLoc->get()
2209                              : mlirLocationUnknownGet(context->get())));
2210           },
2211           py::arg("name"), py::arg("childLoc") = py::none(),
2212           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2213       .def_property_readonly(
2214           "context",
2215           [](PyLocation &self) { return self.getContext().getObject(); },
2216           "Context that owns the Location")
2217       .def("__repr__", [](PyLocation &self) {
2218         PyPrintAccumulator printAccum;
2219         mlirLocationPrint(self, printAccum.getCallback(),
2220                           printAccum.getUserData());
2221         return printAccum.join();
2222       });
2223 
2224   //----------------------------------------------------------------------------
2225   // Mapping of Module
2226   //----------------------------------------------------------------------------
2227   py::class_<PyModule>(m, "Module", py::module_local())
2228       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2229       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2230       .def_static(
2231           "parse",
2232           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2233             MlirModule module = mlirModuleCreateParse(
2234                 context->get(), toMlirStringRef(moduleAsm));
2235             // TODO: Rework error reporting once diagnostic engine is exposed
2236             // in C API.
2237             if (mlirModuleIsNull(module)) {
2238               throw SetPyError(
2239                   PyExc_ValueError,
2240                   "Unable to parse module assembly (see diagnostics)");
2241             }
2242             return PyModule::forModule(module).releaseObject();
2243           },
2244           py::arg("asm"), py::arg("context") = py::none(),
2245           kModuleParseDocstring)
2246       .def_static(
2247           "create",
2248           [](DefaultingPyLocation loc) {
2249             MlirModule module = mlirModuleCreateEmpty(loc);
2250             return PyModule::forModule(module).releaseObject();
2251           },
2252           py::arg("loc") = py::none(), "Creates an empty module")
2253       .def_property_readonly(
2254           "context",
2255           [](PyModule &self) { return self.getContext().getObject(); },
2256           "Context that created the Module")
2257       .def_property_readonly(
2258           "operation",
2259           [](PyModule &self) {
2260             return PyOperation::forOperation(self.getContext(),
2261                                              mlirModuleGetOperation(self.get()),
2262                                              self.getRef().releaseObject())
2263                 .releaseObject();
2264           },
2265           "Accesses the module as an operation")
2266       .def_property_readonly(
2267           "body",
2268           [](PyModule &self) {
2269             PyOperationRef module_op = PyOperation::forOperation(
2270                 self.getContext(), mlirModuleGetOperation(self.get()),
2271                 self.getRef().releaseObject());
2272             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2273             return returnBlock;
2274           },
2275           "Return the block for this module")
2276       .def(
2277           "dump",
2278           [](PyModule &self) {
2279             mlirOperationDump(mlirModuleGetOperation(self.get()));
2280           },
2281           kDumpDocstring)
2282       .def(
2283           "__str__",
2284           [](py::object self) {
2285             // Defer to the operation's __str__.
2286             return self.attr("operation").attr("__str__")();
2287           },
2288           kOperationStrDunderDocstring);
2289 
2290   //----------------------------------------------------------------------------
2291   // Mapping of Operation.
2292   //----------------------------------------------------------------------------
2293   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2294       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2295                              [](PyOperationBase &self) {
2296                                return self.getOperation().getCapsule();
2297                              })
2298       .def("__eq__",
2299            [](PyOperationBase &self, PyOperationBase &other) {
2300              return &self.getOperation() == &other.getOperation();
2301            })
2302       .def("__eq__",
2303            [](PyOperationBase &self, py::object other) { return false; })
2304       .def("__hash__",
2305            [](PyOperationBase &self) {
2306              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2307            })
2308       .def_property_readonly("attributes",
2309                              [](PyOperationBase &self) {
2310                                return PyOpAttributeMap(
2311                                    self.getOperation().getRef());
2312                              })
2313       .def_property_readonly("operands",
2314                              [](PyOperationBase &self) {
2315                                return PyOpOperandList(
2316                                    self.getOperation().getRef());
2317                              })
2318       .def_property_readonly("regions",
2319                              [](PyOperationBase &self) {
2320                                return PyRegionList(
2321                                    self.getOperation().getRef());
2322                              })
2323       .def_property_readonly(
2324           "results",
2325           [](PyOperationBase &self) {
2326             return PyOpResultList(self.getOperation().getRef());
2327           },
2328           "Returns the list of Operation results.")
2329       .def_property_readonly(
2330           "result",
2331           [](PyOperationBase &self) {
2332             auto &operation = self.getOperation();
2333             auto numResults = mlirOperationGetNumResults(operation);
2334             if (numResults != 1) {
2335               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2336               throw SetPyError(
2337                   PyExc_ValueError,
2338                   Twine("Cannot call .result on operation ") +
2339                       StringRef(name.data, name.length) + " which has " +
2340                       Twine(numResults) +
2341                       " results (it is only valid for operations with a "
2342                       "single result)");
2343             }
2344             return PyOpResult(operation.getRef(),
2345                               mlirOperationGetResult(operation, 0));
2346           },
2347           "Shortcut to get an op result if it has only one (throws an error "
2348           "otherwise).")
2349       .def_property_readonly(
2350           "location",
2351           [](PyOperationBase &self) {
2352             PyOperation &operation = self.getOperation();
2353             return PyLocation(operation.getContext(),
2354                               mlirOperationGetLocation(operation.get()));
2355           },
2356           "Returns the source location the operation was defined or derived "
2357           "from.")
2358       .def(
2359           "__str__",
2360           [](PyOperationBase &self) {
2361             return self.getAsm(/*binary=*/false,
2362                                /*largeElementsLimit=*/llvm::None,
2363                                /*enableDebugInfo=*/false,
2364                                /*prettyDebugInfo=*/false,
2365                                /*printGenericOpForm=*/false,
2366                                /*useLocalScope=*/false,
2367                                /*assumeVerified=*/false);
2368           },
2369           "Returns the assembly form of the operation.")
2370       .def("print", &PyOperationBase::print,
2371            // Careful: Lots of arguments must match up with print method.
2372            py::arg("file") = py::none(), py::arg("binary") = false,
2373            py::arg("large_elements_limit") = py::none(),
2374            py::arg("enable_debug_info") = false,
2375            py::arg("pretty_debug_info") = false,
2376            py::arg("print_generic_op_form") = false,
2377            py::arg("use_local_scope") = false,
2378            py::arg("assume_verified") = false, kOperationPrintDocstring)
2379       .def("get_asm", &PyOperationBase::getAsm,
2380            // Careful: Lots of arguments must match up with get_asm method.
2381            py::arg("binary") = false,
2382            py::arg("large_elements_limit") = py::none(),
2383            py::arg("enable_debug_info") = false,
2384            py::arg("pretty_debug_info") = false,
2385            py::arg("print_generic_op_form") = false,
2386            py::arg("use_local_scope") = false,
2387            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
2388       .def(
2389           "verify",
2390           [](PyOperationBase &self) {
2391             return mlirOperationVerify(self.getOperation());
2392           },
2393           "Verify the operation and return true if it passes, false if it "
2394           "fails.")
2395       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2396            "Puts self immediately after the other operation in its parent "
2397            "block.")
2398       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2399            "Puts self immediately before the other operation in its parent "
2400            "block.")
2401       .def(
2402           "detach_from_parent",
2403           [](PyOperationBase &self) {
2404             PyOperation &operation = self.getOperation();
2405             operation.checkValid();
2406             if (!operation.isAttached())
2407               throw py::value_error("Detached operation has no parent.");
2408 
2409             operation.detachFromParent();
2410             return operation.createOpView();
2411           },
2412           "Detaches the operation from its parent block.");
2413 
2414   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2415       .def_static("create", &PyOperation::create, py::arg("name"),
2416                   py::arg("results") = py::none(),
2417                   py::arg("operands") = py::none(),
2418                   py::arg("attributes") = py::none(),
2419                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2420                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2421                   kOperationCreateDocstring)
2422       .def_property_readonly("parent",
2423                              [](PyOperation &self) -> py::object {
2424                                auto parent = self.getParentOperation();
2425                                if (parent)
2426                                  return parent->getObject();
2427                                return py::none();
2428                              })
2429       .def("erase", &PyOperation::erase)
2430       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2431                              &PyOperation::getCapsule)
2432       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2433       .def_property_readonly("name",
2434                              [](PyOperation &self) {
2435                                self.checkValid();
2436                                MlirOperation operation = self.get();
2437                                MlirStringRef name = mlirIdentifierStr(
2438                                    mlirOperationGetName(operation));
2439                                return py::str(name.data, name.length);
2440                              })
2441       .def_property_readonly(
2442           "context",
2443           [](PyOperation &self) {
2444             self.checkValid();
2445             return self.getContext().getObject();
2446           },
2447           "Context that owns the Operation")
2448       .def_property_readonly("opview", &PyOperation::createOpView);
2449 
  // Python `OpView`: base class for op-specific views wrapping an Operation.
  // The class object is kept in a local so that class-level `_ODS_*`
  // attributes and the `build_generic` classmethod can be attached after the
  // binding is defined.
  auto opViewClass =
      py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
          .def(py::init<py::object>(), py::arg("operation"))
          .def_property_readonly("operation", &PyOpView::getOperationObject)
          .def_property_readonly(
              "context",
              [](PyOpView &self) {
                return self.getOperation().getContext().getObject();
              },
              "Context that owns the Operation")
          // Printing delegates to the wrapped operation object's str().
          .def("__str__", [](PyOpView &self) {
            return py::str(self.getOperationObject());
          });
  // Default class-level ODS metadata consumed by build_generic.
  // NOTE(review): _ODS_REGIONS looks like a (region count, flag) pair; its
  // exact meaning is defined by PyOpView::buildGeneric — confirm there.
  // The segment attributes default to None here; presumably generated
  // subclasses override them when they use segmented operands/results.
  opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
  opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
  opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
  opViewClass.attr("build_generic") = classmethod(
      &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
      py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
      py::arg("successors") = py::none(), py::arg("regions") = py::none(),
      py::arg("loc") = py::none(), py::arg("ip") = py::none(),
      "Builds a specific, generated OpView based on class level attributes.");
2472 
2473   //----------------------------------------------------------------------------
2474   // Mapping of PyRegion.
2475   //----------------------------------------------------------------------------
2476   py::class_<PyRegion>(m, "Region", py::module_local())
2477       .def_property_readonly(
2478           "blocks",
2479           [](PyRegion &self) {
2480             return PyBlockList(self.getParentOperation(), self.get());
2481           },
2482           "Returns a forward-optimized sequence of blocks.")
2483       .def_property_readonly(
2484           "owner",
2485           [](PyRegion &self) {
2486             return self.getParentOperation()->createOpView();
2487           },
2488           "Returns the operation owning this region.")
2489       .def(
2490           "__iter__",
2491           [](PyRegion &self) {
2492             self.checkValid();
2493             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2494             return PyBlockIterator(self.getParentOperation(), firstBlock);
2495           },
2496           "Iterates over blocks in the region.")
2497       .def("__eq__",
2498            [](PyRegion &self, PyRegion &other) {
2499              return self.get().ptr == other.get().ptr;
2500            })
2501       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2502 
2503   //----------------------------------------------------------------------------
2504   // Mapping of PyBlock.
2505   //----------------------------------------------------------------------------
2506   py::class_<PyBlock>(m, "Block", py::module_local())
2507       .def_property_readonly(
2508           "owner",
2509           [](PyBlock &self) {
2510             return self.getParentOperation()->createOpView();
2511           },
2512           "Returns the owning operation of this block.")
2513       .def_property_readonly(
2514           "region",
2515           [](PyBlock &self) {
2516             MlirRegion region = mlirBlockGetParentRegion(self.get());
2517             return PyRegion(self.getParentOperation(), region);
2518           },
2519           "Returns the owning region of this block.")
2520       .def_property_readonly(
2521           "arguments",
2522           [](PyBlock &self) {
2523             return PyBlockArgumentList(self.getParentOperation(), self.get());
2524           },
2525           "Returns a list of block arguments.")
2526       .def_property_readonly(
2527           "operations",
2528           [](PyBlock &self) {
2529             return PyOperationList(self.getParentOperation(), self.get());
2530           },
2531           "Returns a forward-optimized sequence of operations.")
2532       .def_static(
2533           "create_at_start",
2534           [](PyRegion &parent, py::list pyArgTypes) {
2535             parent.checkValid();
2536             llvm::SmallVector<MlirType, 4> argTypes;
2537             argTypes.reserve(pyArgTypes.size());
2538             for (auto &pyArg : pyArgTypes) {
2539               argTypes.push_back(pyArg.cast<PyType &>());
2540             }
2541 
2542             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2543             mlirRegionInsertOwnedBlock(parent, 0, block);
2544             return PyBlock(parent.getParentOperation(), block);
2545           },
2546           py::arg("parent"), py::arg("arg_types") = py::list(),
2547           "Creates and returns a new Block at the beginning of the given "
2548           "region (with given argument types).")
2549       .def(
2550           "create_before",
2551           [](PyBlock &self, py::args pyArgTypes) {
2552             self.checkValid();
2553             llvm::SmallVector<MlirType, 4> argTypes;
2554             argTypes.reserve(pyArgTypes.size());
2555             for (auto &pyArg : pyArgTypes) {
2556               argTypes.push_back(pyArg.cast<PyType &>());
2557             }
2558 
2559             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2560             MlirRegion region = mlirBlockGetParentRegion(self.get());
2561             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2562             return PyBlock(self.getParentOperation(), block);
2563           },
2564           "Creates and returns a new Block before this block "
2565           "(with given argument types).")
2566       .def(
2567           "create_after",
2568           [](PyBlock &self, py::args pyArgTypes) {
2569             self.checkValid();
2570             llvm::SmallVector<MlirType, 4> argTypes;
2571             argTypes.reserve(pyArgTypes.size());
2572             for (auto &pyArg : pyArgTypes) {
2573               argTypes.push_back(pyArg.cast<PyType &>());
2574             }
2575 
2576             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2577             MlirRegion region = mlirBlockGetParentRegion(self.get());
2578             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2579             return PyBlock(self.getParentOperation(), block);
2580           },
2581           "Creates and returns a new Block after this block "
2582           "(with given argument types).")
2583       .def(
2584           "__iter__",
2585           [](PyBlock &self) {
2586             self.checkValid();
2587             MlirOperation firstOperation =
2588                 mlirBlockGetFirstOperation(self.get());
2589             return PyOperationIterator(self.getParentOperation(),
2590                                        firstOperation);
2591           },
2592           "Iterates over operations in the block.")
2593       .def("__eq__",
2594            [](PyBlock &self, PyBlock &other) {
2595              return self.get().ptr == other.get().ptr;
2596            })
2597       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2598       .def(
2599           "__str__",
2600           [](PyBlock &self) {
2601             self.checkValid();
2602             PyPrintAccumulator printAccum;
2603             mlirBlockPrint(self.get(), printAccum.getCallback(),
2604                            printAccum.getUserData());
2605             return printAccum.join();
2606           },
2607           "Returns the assembly form of the block.")
2608       .def(
2609           "append",
2610           [](PyBlock &self, PyOperationBase &operation) {
2611             if (operation.getOperation().isAttached())
2612               operation.getOperation().detachFromParent();
2613 
2614             MlirOperation mlirOperation = operation.getOperation().get();
2615             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2616             operation.getOperation().setAttached(
2617                 self.getParentOperation().getObject());
2618           },
2619           py::arg("operation"),
2620           "Appends an operation to this block. If the operation is currently "
2621           "in another block, it will be moved.");
2622 
2623   //----------------------------------------------------------------------------
2624   // Mapping of PyInsertionPoint.
2625   //----------------------------------------------------------------------------
2626 
2627   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2628       .def(py::init<PyBlock &>(), py::arg("block"),
2629            "Inserts after the last operation but still inside the block.")
2630       .def("__enter__", &PyInsertionPoint::contextEnter)
2631       .def("__exit__", &PyInsertionPoint::contextExit)
2632       .def_property_readonly_static(
2633           "current",
2634           [](py::object & /*class*/) {
2635             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2636             if (!ip)
2637               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2638             return ip;
2639           },
2640           "Gets the InsertionPoint bound to the current thread or raises "
2641           "ValueError if none has been set")
2642       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2643            "Inserts before a referenced operation.")
2644       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2645                   py::arg("block"), "Inserts at the beginning of the block.")
2646       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2647                   py::arg("block"), "Inserts before the block terminator.")
2648       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2649            "Inserts an operation.")
2650       .def_property_readonly(
2651           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2652           "Returns the block that this InsertionPoint points to.");
2653 
2654   //----------------------------------------------------------------------------
2655   // Mapping of PyAttribute.
2656   //----------------------------------------------------------------------------
2657   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2658       // Delegate to the PyAttribute copy constructor, which will also lifetime
2659       // extend the backing context which owns the MlirAttribute.
2660       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2661            "Casts the passed attribute to the generic Attribute")
2662       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2663                              &PyAttribute::getCapsule)
2664       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2665       .def_static(
2666           "parse",
2667           [](std::string attrSpec, DefaultingPyMlirContext context) {
2668             MlirAttribute type = mlirAttributeParseGet(
2669                 context->get(), toMlirStringRef(attrSpec));
2670             // TODO: Rework error reporting once diagnostic engine is exposed
2671             // in C API.
2672             if (mlirAttributeIsNull(type)) {
2673               throw SetPyError(PyExc_ValueError,
2674                                Twine("Unable to parse attribute: '") +
2675                                    attrSpec + "'");
2676             }
2677             return PyAttribute(context->getRef(), type);
2678           },
2679           py::arg("asm"), py::arg("context") = py::none(),
2680           "Parses an attribute from an assembly form")
2681       .def_property_readonly(
2682           "context",
2683           [](PyAttribute &self) { return self.getContext().getObject(); },
2684           "Context that owns the Attribute")
2685       .def_property_readonly("type",
2686                              [](PyAttribute &self) {
2687                                return PyType(self.getContext()->getRef(),
2688                                              mlirAttributeGetType(self));
2689                              })
2690       .def(
2691           "get_named",
2692           [](PyAttribute &self, std::string name) {
2693             return PyNamedAttribute(self, std::move(name));
2694           },
2695           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2696       .def("__eq__",
2697            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2698       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2699       .def("__hash__",
2700            [](PyAttribute &self) {
2701              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2702            })
2703       .def(
2704           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2705           kDumpDocstring)
2706       .def(
2707           "__str__",
2708           [](PyAttribute &self) {
2709             PyPrintAccumulator printAccum;
2710             mlirAttributePrint(self, printAccum.getCallback(),
2711                                printAccum.getUserData());
2712             return printAccum.join();
2713           },
2714           "Returns the assembly form of the Attribute.")
2715       .def("__repr__", [](PyAttribute &self) {
2716         // Generally, assembly formats are not printed for __repr__ because
2717         // this can cause exceptionally long debug output and exceptions.
2718         // However, attribute values are generally considered useful and are
2719         // printed. This may need to be re-evaluated if debug dumps end up
2720         // being excessive.
2721         PyPrintAccumulator printAccum;
2722         printAccum.parts.append("Attribute(");
2723         mlirAttributePrint(self, printAccum.getCallback(),
2724                            printAccum.getUserData());
2725         printAccum.parts.append(")");
2726         return printAccum.join();
2727       });
2728 
2729   //----------------------------------------------------------------------------
2730   // Mapping of PyNamedAttribute
2731   //----------------------------------------------------------------------------
2732   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2733       .def("__repr__",
2734            [](PyNamedAttribute &self) {
2735              PyPrintAccumulator printAccum;
2736              printAccum.parts.append("NamedAttribute(");
2737              printAccum.parts.append(
2738                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
2739                          mlirIdentifierStr(self.namedAttr.name).length));
2740              printAccum.parts.append("=");
2741              mlirAttributePrint(self.namedAttr.attribute,
2742                                 printAccum.getCallback(),
2743                                 printAccum.getUserData());
2744              printAccum.parts.append(")");
2745              return printAccum.join();
2746            })
2747       .def_property_readonly(
2748           "name",
2749           [](PyNamedAttribute &self) {
2750             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2751                            mlirIdentifierStr(self.namedAttr.name).length);
2752           },
2753           "The name of the NamedAttribute binding")
2754       .def_property_readonly(
2755           "attr",
2756           [](PyNamedAttribute &self) {
2757             // TODO: When named attribute is removed/refactored, also remove
2758             // this constructor (it does an inefficient table lookup).
2759             auto contextRef = PyMlirContext::forContext(
2760                 mlirAttributeGetContext(self.namedAttr.attribute));
2761             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2762           },
2763           py::keep_alive<0, 1>(),
2764           "The underlying generic attribute of the NamedAttribute binding");
2765 
2766   //----------------------------------------------------------------------------
2767   // Mapping of PyType.
2768   //----------------------------------------------------------------------------
2769   py::class_<PyType>(m, "Type", py::module_local())
2770       // Delegate to the PyType copy constructor, which will also lifetime
2771       // extend the backing context which owns the MlirType.
2772       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2773            "Casts the passed type to the generic Type")
2774       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2775       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2776       .def_static(
2777           "parse",
2778           [](std::string typeSpec, DefaultingPyMlirContext context) {
2779             MlirType type =
2780                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2781             // TODO: Rework error reporting once diagnostic engine is exposed
2782             // in C API.
2783             if (mlirTypeIsNull(type)) {
2784               throw SetPyError(PyExc_ValueError,
2785                                Twine("Unable to parse type: '") + typeSpec +
2786                                    "'");
2787             }
2788             return PyType(context->getRef(), type);
2789           },
2790           py::arg("asm"), py::arg("context") = py::none(),
2791           kContextParseTypeDocstring)
2792       .def_property_readonly(
2793           "context", [](PyType &self) { return self.getContext().getObject(); },
2794           "Context that owns the Type")
2795       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2796       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2797       .def("__hash__",
2798            [](PyType &self) {
2799              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2800            })
2801       .def(
2802           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2803       .def(
2804           "__str__",
2805           [](PyType &self) {
2806             PyPrintAccumulator printAccum;
2807             mlirTypePrint(self, printAccum.getCallback(),
2808                           printAccum.getUserData());
2809             return printAccum.join();
2810           },
2811           "Returns the assembly form of the type.")
2812       .def("__repr__", [](PyType &self) {
2813         // Generally, assembly formats are not printed for __repr__ because
2814         // this can cause exceptionally long debug output and exceptions.
2815         // However, types are an exception as they typically have compact
2816         // assembly forms and printing them is useful.
2817         PyPrintAccumulator printAccum;
2818         printAccum.parts.append("Type(");
2819         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2820         printAccum.parts.append(")");
2821         return printAccum.join();
2822       });
2823 
2824   //----------------------------------------------------------------------------
2825   // Mapping of Value.
2826   //----------------------------------------------------------------------------
2827   py::class_<PyValue>(m, "Value", py::module_local())
2828       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2829       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2830       .def_property_readonly(
2831           "context",
2832           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2833           "Context in which the value lives.")
2834       .def(
2835           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2836           kDumpDocstring)
2837       .def_property_readonly(
2838           "owner",
2839           [](PyValue &self) {
2840             assert(mlirOperationEqual(self.getParentOperation()->get(),
2841                                       mlirOpResultGetOwner(self.get())) &&
2842                    "expected the owner of the value in Python to match that in "
2843                    "the IR");
2844             return self.getParentOperation().getObject();
2845           })
2846       .def("__eq__",
2847            [](PyValue &self, PyValue &other) {
2848              return self.get().ptr == other.get().ptr;
2849            })
2850       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2851       .def("__hash__",
2852            [](PyValue &self) {
2853              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2854            })
2855       .def(
2856           "__str__",
2857           [](PyValue &self) {
2858             PyPrintAccumulator printAccum;
2859             printAccum.parts.append("Value(");
2860             mlirValuePrint(self.get(), printAccum.getCallback(),
2861                            printAccum.getUserData());
2862             printAccum.parts.append(")");
2863             return printAccum.join();
2864           },
2865           kValueDunderStrDocstring)
2866       .def_property_readonly("type", [](PyValue &self) {
2867         return PyType(self.getParentOperation()->getContext(),
2868                       mlirValueGetType(self.get()));
2869       });
2870   PyBlockArgument::bind(m);
2871   PyOpResult::bind(m);
2872 
2873   //----------------------------------------------------------------------------
2874   // Mapping of SymbolTable.
2875   //----------------------------------------------------------------------------
2876   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
2877       .def(py::init<PyOperationBase &>())
2878       .def("__getitem__", &PySymbolTable::dunderGetItem)
2879       .def("insert", &PySymbolTable::insert, py::arg("operation"))
2880       .def("erase", &PySymbolTable::erase, py::arg("operation"))
2881       .def("__delitem__", &PySymbolTable::dunderDel)
2882       .def("__contains__",
2883            [](PySymbolTable &table, const std::string &name) {
2884              return !mlirOperationIsNull(mlirSymbolTableLookup(
2885                  table, mlirStringRefCreate(name.data(), name.length())));
2886            })
2887       // Static helpers.
2888       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
2889                   py::arg("symbol"), py::arg("name"))
2890       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
2891                   py::arg("symbol"))
2892       .def_static("get_visibility", &PySymbolTable::getVisibility,
2893                   py::arg("symbol"))
2894       .def_static("set_visibility", &PySymbolTable::setVisibility,
2895                   py::arg("symbol"), py::arg("visibility"))
2896       .def_static("replace_all_symbol_uses",
2897                   &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
2898                   py::arg("new_symbol"), py::arg("from_op"))
2899       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
2900                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
2901                   py::arg("callback"));
2902 
  // Container bindings: each helper class registers its Python type on module
  // `m` through its static bind(). Judging by the names, these are the
  // list/iterator/map wrappers over blocks, operations, regions, operands,
  // results, block arguments and op attributes — confirm in IRModule.h.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings: registers the global debug-flag helper on module `m`
  // (presumably backed by the mlir-c/Debug.h API included by this file).
  PyGlobalDebugFlag::bind(m);
2917 }
2918