1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include <pybind11/stl.h>
23 
24 namespace py = pybind11;
25 using namespace mlir;
26 using namespace mlir::python;
27 
28 using llvm::SmallVector;
29 using llvm::StringRef;
30 using llvm::Twine;
31 
32 //------------------------------------------------------------------------------
33 // Docstrings (trivial, non-duplicated docstrings are included inline).
34 //------------------------------------------------------------------------------
35 
// Docstrings for Context-level parsing and Location factory methods, plus
// Module parsing. These are passed verbatim to pybind11 as Python docstrings.

static constexpr char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static constexpr char kContextGetCallSiteLocationDocstring[] =
    "Gets a Location representing a caller and callsite";

static constexpr char kContextGetFileLocationDocstring[] =
    "Gets a Location representing a file, line and column";

static constexpr char kContextGetFusedLocationDocstring[] =
    "Gets a Location representing a fused location with optional metadata";

static constexpr char kContextGetNameLocationDocString[] =
    "Gets a Location representing a named location with optional child "
    "location";

static constexpr char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";
63 
// Python docstring for Operation.create(); describes every keyword argument.
static constexpr char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
81 
// Python docstring for Operation.print(). Keyword names are the snake_case
// spellings of PyOperationBase::print's parameters (e.g. useLocalScope ->
// use_local_scope).
static constexpr char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file-like object.

Args:
  file: The file-like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: If set, elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";
107 
// Docstrings for the remaining printing/debug helpers and for block/value
// bindings. Passed verbatim to pybind11.

static constexpr char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static constexpr char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static constexpr char kDumpDocstring[] =
    "Dumps a debug representation of the object to stderr.";

static constexpr char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static constexpr char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
146 
147 //------------------------------------------------------------------------------
148 // Utilities.
149 //------------------------------------------------------------------------------
150 
151 /// Helper for creating an @classmethod.
152 template <class Func, typename... Args>
153 py::object classmethod(Func f, Args... args) {
154   py::object cf = py::cpp_function(f, args...);
155   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
156 }
157 
158 static py::object
159 createCustomDialectWrapper(const std::string &dialectNamespace,
160                            py::object dialectDescriptor) {
161   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
162   if (!dialectClass) {
163     // Use the base class.
164     return py::cast(PyDialect(std::move(dialectDescriptor)));
165   }
166 
167   // Create the custom implementation.
168   return (*dialectClass)(std::move(dialectDescriptor));
169 }
170 
171 static MlirStringRef toMlirStringRef(const std::string &s) {
172   return mlirStringRefCreate(s.data(), s.size());
173 }
174 
175 /// Wrapper for the global LLVM debugging flag.
176 struct PyGlobalDebugFlag {
177   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
178 
179   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
180 
181   static void bind(py::module &m) {
182     // Debug flags.
183     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
184         .def_property_static("flag", &PyGlobalDebugFlag::get,
185                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
186   }
187 };
188 
189 //------------------------------------------------------------------------------
190 // Collections.
191 //------------------------------------------------------------------------------
192 
193 namespace {
194 
195 class PyRegionIterator {
196 public:
197   PyRegionIterator(PyOperationRef operation)
198       : operation(std::move(operation)) {}
199 
200   PyRegionIterator &dunderIter() { return *this; }
201 
202   PyRegion dunderNext() {
203     operation->checkValid();
204     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
205       throw py::stop_iteration();
206     }
207     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
208     return PyRegion(operation, region);
209   }
210 
211   static void bind(py::module &m) {
212     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
213         .def("__iter__", &PyRegionIterator::dunderIter)
214         .def("__next__", &PyRegionIterator::dunderNext);
215   }
216 
217 private:
218   PyOperationRef operation;
219   int nextIndex = 0;
220 };
221 
222 /// Regions of an op are fixed length and indexed numerically so are represented
223 /// with a sequence-like container.
224 class PyRegionList {
225 public:
226   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
227 
228   intptr_t dunderLen() {
229     operation->checkValid();
230     return mlirOperationGetNumRegions(operation->get());
231   }
232 
233   PyRegion dunderGetItem(intptr_t index) {
234     // dunderLen checks validity.
235     if (index < 0 || index >= dunderLen()) {
236       throw SetPyError(PyExc_IndexError,
237                        "attempt to access out of bounds region");
238     }
239     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
240     return PyRegion(operation, region);
241   }
242 
243   static void bind(py::module &m) {
244     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
245         .def("__len__", &PyRegionList::dunderLen)
246         .def("__getitem__", &PyRegionList::dunderGetItem);
247   }
248 
249 private:
250   PyOperationRef operation;
251 };
252 
253 class PyBlockIterator {
254 public:
255   PyBlockIterator(PyOperationRef operation, MlirBlock next)
256       : operation(std::move(operation)), next(next) {}
257 
258   PyBlockIterator &dunderIter() { return *this; }
259 
260   PyBlock dunderNext() {
261     operation->checkValid();
262     if (mlirBlockIsNull(next)) {
263       throw py::stop_iteration();
264     }
265 
266     PyBlock returnBlock(operation, next);
267     next = mlirBlockGetNextInRegion(next);
268     return returnBlock;
269   }
270 
271   static void bind(py::module &m) {
272     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
273         .def("__iter__", &PyBlockIterator::dunderIter)
274         .def("__next__", &PyBlockIterator::dunderNext);
275   }
276 
277 private:
278   PyOperationRef operation;
279   MlirBlock next;
280 };
281 
282 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
283 /// we present them as a more full-featured list-like container but optimize
284 /// it for forward iteration. Blocks are always owned by a region.
285 class PyBlockList {
286 public:
287   PyBlockList(PyOperationRef operation, MlirRegion region)
288       : operation(std::move(operation)), region(region) {}
289 
290   PyBlockIterator dunderIter() {
291     operation->checkValid();
292     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
293   }
294 
295   intptr_t dunderLen() {
296     operation->checkValid();
297     intptr_t count = 0;
298     MlirBlock block = mlirRegionGetFirstBlock(region);
299     while (!mlirBlockIsNull(block)) {
300       count += 1;
301       block = mlirBlockGetNextInRegion(block);
302     }
303     return count;
304   }
305 
306   PyBlock dunderGetItem(intptr_t index) {
307     operation->checkValid();
308     if (index < 0) {
309       throw SetPyError(PyExc_IndexError,
310                        "attempt to access out of bounds block");
311     }
312     MlirBlock block = mlirRegionGetFirstBlock(region);
313     while (!mlirBlockIsNull(block)) {
314       if (index == 0) {
315         return PyBlock(operation, block);
316       }
317       block = mlirBlockGetNextInRegion(block);
318       index -= 1;
319     }
320     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
321   }
322 
323   PyBlock appendBlock(py::args pyArgTypes) {
324     operation->checkValid();
325     llvm::SmallVector<MlirType, 4> argTypes;
326     argTypes.reserve(pyArgTypes.size());
327     for (auto &pyArg : pyArgTypes) {
328       argTypes.push_back(pyArg.cast<PyType &>());
329     }
330 
331     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
332     mlirRegionAppendOwnedBlock(region, block);
333     return PyBlock(operation, block);
334   }
335 
336   static void bind(py::module &m) {
337     py::class_<PyBlockList>(m, "BlockList", py::module_local())
338         .def("__getitem__", &PyBlockList::dunderGetItem)
339         .def("__iter__", &PyBlockList::dunderIter)
340         .def("__len__", &PyBlockList::dunderLen)
341         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
342   }
343 
344 private:
345   PyOperationRef operation;
346   MlirRegion region;
347 };
348 
349 class PyOperationIterator {
350 public:
351   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
352       : parentOperation(std::move(parentOperation)), next(next) {}
353 
354   PyOperationIterator &dunderIter() { return *this; }
355 
356   py::object dunderNext() {
357     parentOperation->checkValid();
358     if (mlirOperationIsNull(next)) {
359       throw py::stop_iteration();
360     }
361 
362     PyOperationRef returnOperation =
363         PyOperation::forOperation(parentOperation->getContext(), next);
364     next = mlirOperationGetNextInBlock(next);
365     return returnOperation->createOpView();
366   }
367 
368   static void bind(py::module &m) {
369     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
370         .def("__iter__", &PyOperationIterator::dunderIter)
371         .def("__next__", &PyOperationIterator::dunderNext);
372   }
373 
374 private:
375   PyOperationRef parentOperation;
376   MlirOperation next;
377 };
378 
379 /// Operations are exposed by the C-API as a forward-only linked list. In
380 /// Python, we present them as a more full-featured list-like container but
381 /// optimize it for forward iteration. Iterable operations are always owned
382 /// by a block.
383 class PyOperationList {
384 public:
385   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
386       : parentOperation(std::move(parentOperation)), block(block) {}
387 
388   PyOperationIterator dunderIter() {
389     parentOperation->checkValid();
390     return PyOperationIterator(parentOperation,
391                                mlirBlockGetFirstOperation(block));
392   }
393 
394   intptr_t dunderLen() {
395     parentOperation->checkValid();
396     intptr_t count = 0;
397     MlirOperation childOp = mlirBlockGetFirstOperation(block);
398     while (!mlirOperationIsNull(childOp)) {
399       count += 1;
400       childOp = mlirOperationGetNextInBlock(childOp);
401     }
402     return count;
403   }
404 
405   py::object dunderGetItem(intptr_t index) {
406     parentOperation->checkValid();
407     if (index < 0) {
408       throw SetPyError(PyExc_IndexError,
409                        "attempt to access out of bounds operation");
410     }
411     MlirOperation childOp = mlirBlockGetFirstOperation(block);
412     while (!mlirOperationIsNull(childOp)) {
413       if (index == 0) {
414         return PyOperation::forOperation(parentOperation->getContext(), childOp)
415             ->createOpView();
416       }
417       childOp = mlirOperationGetNextInBlock(childOp);
418       index -= 1;
419     }
420     throw SetPyError(PyExc_IndexError,
421                      "attempt to access out of bounds operation");
422   }
423 
424   static void bind(py::module &m) {
425     py::class_<PyOperationList>(m, "OperationList", py::module_local())
426         .def("__getitem__", &PyOperationList::dunderGetItem)
427         .def("__iter__", &PyOperationList::dunderIter)
428         .def("__len__", &PyOperationList::dunderLen);
429   }
430 
431 private:
432   PyOperationRef parentOperation;
433   MlirBlock block;
434 };
435 
436 } // namespace
437 
438 //------------------------------------------------------------------------------
439 // PyMlirContext
440 //------------------------------------------------------------------------------
441 
442 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
443   py::gil_scoped_acquire acquire;
444   auto &liveContexts = getLiveContexts();
445   liveContexts[context.ptr] = this;
446 }
447 
448 PyMlirContext::~PyMlirContext() {
449   // Note that the only public way to construct an instance is via the
450   // forContext method, which always puts the associated handle into
451   // liveContexts.
452   py::gil_scoped_acquire acquire;
453   getLiveContexts().erase(context.ptr);
454   mlirContextDestroy(context);
455 }
456 
457 py::object PyMlirContext::getCapsule() {
458   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
459 }
460 
461 py::object PyMlirContext::createFromCapsule(py::object capsule) {
462   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
463   if (mlirContextIsNull(rawContext))
464     throw py::error_already_set();
465   return forContext(rawContext).releaseObject();
466 }
467 
468 PyMlirContext *PyMlirContext::createNewContextForInit() {
469   MlirContext context = mlirContextCreate();
470   mlirRegisterAllDialects(context);
471   return new PyMlirContext(context);
472 }
473 
474 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
475   py::gil_scoped_acquire acquire;
476   auto &liveContexts = getLiveContexts();
477   auto it = liveContexts.find(context.ptr);
478   if (it == liveContexts.end()) {
479     // Create.
480     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
481     py::object pyRef = py::cast(unownedContextWrapper);
482     assert(pyRef && "cast to py::object failed");
483     liveContexts[context.ptr] = unownedContextWrapper;
484     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
485   }
486   // Use existing.
487   py::object pyRef = py::cast(it->second);
488   return PyMlirContextRef(it->second, std::move(pyRef));
489 }
490 
491 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
492   static LiveContextMap liveContexts;
493   return liveContexts;
494 }
495 
496 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
497 
498 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
499 
500 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
501 
502 pybind11::object PyMlirContext::contextEnter() {
503   return PyThreadContextEntry::pushContext(*this);
504 }
505 
506 void PyMlirContext::contextExit(pybind11::object excType,
507                                 pybind11::object excVal,
508                                 pybind11::object excTb) {
509   PyThreadContextEntry::popContext(*this);
510 }
511 
512 PyMlirContext &DefaultingPyMlirContext::resolve() {
513   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
514   if (!context) {
515     throw SetPyError(
516         PyExc_RuntimeError,
517         "An MLIR function requires a Context but none was provided in the call "
518         "or from the surrounding environment. Either pass to the function with "
519         "a 'context=' argument or establish a default using 'with Context():'");
520   }
521   return *context;
522 }
523 
524 //------------------------------------------------------------------------------
525 // PyThreadContextEntry management
526 //------------------------------------------------------------------------------
527 
528 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
529   static thread_local std::vector<PyThreadContextEntry> stack;
530   return stack;
531 }
532 
533 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
534   auto &stack = getStack();
535   if (stack.empty())
536     return nullptr;
537   return &stack.back();
538 }
539 
540 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
541                                 py::object insertionPoint,
542                                 py::object location) {
543   auto &stack = getStack();
544   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
545                      std::move(location));
546   // If the new stack has more than one entry and the context of the new top
547   // entry matches the previous, copy the insertionPoint and location from the
548   // previous entry if missing from the new top entry.
549   if (stack.size() > 1) {
550     auto &prev = *(stack.rbegin() + 1);
551     auto &current = stack.back();
552     if (current.context.is(prev.context)) {
553       // Default non-context objects from the previous entry.
554       if (!current.insertionPoint)
555         current.insertionPoint = prev.insertionPoint;
556       if (!current.location)
557         current.location = prev.location;
558     }
559   }
560 }
561 
562 PyMlirContext *PyThreadContextEntry::getContext() {
563   if (!context)
564     return nullptr;
565   return py::cast<PyMlirContext *>(context);
566 }
567 
568 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
569   if (!insertionPoint)
570     return nullptr;
571   return py::cast<PyInsertionPoint *>(insertionPoint);
572 }
573 
574 PyLocation *PyThreadContextEntry::getLocation() {
575   if (!location)
576     return nullptr;
577   return py::cast<PyLocation *>(location);
578 }
579 
580 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
581   auto *tos = getTopOfStack();
582   return tos ? tos->getContext() : nullptr;
583 }
584 
585 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
586   auto *tos = getTopOfStack();
587   return tos ? tos->getInsertionPoint() : nullptr;
588 }
589 
590 PyLocation *PyThreadContextEntry::getDefaultLocation() {
591   auto *tos = getTopOfStack();
592   return tos ? tos->getLocation() : nullptr;
593 }
594 
595 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
596   py::object contextObj = py::cast(context);
597   push(FrameKind::Context, /*context=*/contextObj,
598        /*insertionPoint=*/py::object(),
599        /*location=*/py::object());
600   return contextObj;
601 }
602 
603 void PyThreadContextEntry::popContext(PyMlirContext &context) {
604   auto &stack = getStack();
605   if (stack.empty())
606     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
607   auto &tos = stack.back();
608   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
609     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
610   stack.pop_back();
611 }
612 
613 py::object
614 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
615   py::object contextObj =
616       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
617   py::object insertionPointObj = py::cast(insertionPoint);
618   push(FrameKind::InsertionPoint,
619        /*context=*/contextObj,
620        /*insertionPoint=*/insertionPointObj,
621        /*location=*/py::object());
622   return insertionPointObj;
623 }
624 
625 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
626   auto &stack = getStack();
627   if (stack.empty())
628     throw SetPyError(PyExc_RuntimeError,
629                      "Unbalanced InsertionPoint enter/exit");
630   auto &tos = stack.back();
631   if (tos.frameKind != FrameKind::InsertionPoint &&
632       tos.getInsertionPoint() != &insertionPoint)
633     throw SetPyError(PyExc_RuntimeError,
634                      "Unbalanced InsertionPoint enter/exit");
635   stack.pop_back();
636 }
637 
638 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
639   py::object contextObj = location.getContext().getObject();
640   py::object locationObj = py::cast(location);
641   push(FrameKind::Location, /*context=*/contextObj,
642        /*insertionPoint=*/py::object(),
643        /*location=*/locationObj);
644   return locationObj;
645 }
646 
647 void PyThreadContextEntry::popLocation(PyLocation &location) {
648   auto &stack = getStack();
649   if (stack.empty())
650     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
651   auto &tos = stack.back();
652   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
653     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
654   stack.pop_back();
655 }
656 
657 //------------------------------------------------------------------------------
658 // PyDialect, PyDialectDescriptor, PyDialects
659 //------------------------------------------------------------------------------
660 
661 MlirDialect PyDialects::getDialectForKey(const std::string &key,
662                                          bool attrError) {
663   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
664                                                     {key.data(), key.size()});
665   if (mlirDialectIsNull(dialect)) {
666     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
667                      Twine("Dialect '") + key + "' not found");
668   }
669   return dialect;
670 }
671 
672 //------------------------------------------------------------------------------
673 // PyLocation
674 //------------------------------------------------------------------------------
675 
676 py::object PyLocation::getCapsule() {
677   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
678 }
679 
680 PyLocation PyLocation::createFromCapsule(py::object capsule) {
681   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
682   if (mlirLocationIsNull(rawLoc))
683     throw py::error_already_set();
684   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
685                     rawLoc);
686 }
687 
688 py::object PyLocation::contextEnter() {
689   return PyThreadContextEntry::pushLocation(*this);
690 }
691 
692 void PyLocation::contextExit(py::object excType, py::object excVal,
693                              py::object excTb) {
694   PyThreadContextEntry::popLocation(*this);
695 }
696 
697 PyLocation &DefaultingPyLocation::resolve() {
698   auto *location = PyThreadContextEntry::getDefaultLocation();
699   if (!location) {
700     throw SetPyError(
701         PyExc_RuntimeError,
702         "An MLIR function requires a Location but none was provided in the "
703         "call or from the surrounding environment. Either pass to the function "
704         "with a 'loc=' argument or establish a default using 'with loc:'");
705   }
706   return *location;
707 }
708 
709 //------------------------------------------------------------------------------
710 // PyModule
711 //------------------------------------------------------------------------------
712 
713 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
714     : BaseContextObject(std::move(contextRef)), module(module) {}
715 
716 PyModule::~PyModule() {
717   py::gil_scoped_acquire acquire;
718   auto &liveModules = getContext()->liveModules;
719   assert(liveModules.count(module.ptr) == 1 &&
720          "destroying module not in live map");
721   liveModules.erase(module.ptr);
722   mlirModuleDestroy(module);
723 }
724 
725 PyModuleRef PyModule::forModule(MlirModule module) {
726   MlirContext context = mlirModuleGetContext(module);
727   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
728 
729   py::gil_scoped_acquire acquire;
730   auto &liveModules = contextRef->liveModules;
731   auto it = liveModules.find(module.ptr);
732   if (it == liveModules.end()) {
733     // Create.
734     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
735     // Note that the default return value policy on cast is automatic_reference,
736     // which does not take ownership (delete will not be called).
737     // Just be explicit.
738     py::object pyRef =
739         py::cast(unownedModule, py::return_value_policy::take_ownership);
740     unownedModule->handle = pyRef;
741     liveModules[module.ptr] =
742         std::make_pair(unownedModule->handle, unownedModule);
743     return PyModuleRef(unownedModule, std::move(pyRef));
744   }
745   // Use existing.
746   PyModule *existing = it->second.second;
747   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
748   return PyModuleRef(existing, std::move(pyRef));
749 }
750 
751 py::object PyModule::createFromCapsule(py::object capsule) {
752   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
753   if (mlirModuleIsNull(rawModule))
754     throw py::error_already_set();
755   return forModule(rawModule).releaseObject();
756 }
757 
758 py::object PyModule::getCapsule() {
759   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
760 }
761 
762 //------------------------------------------------------------------------------
763 // PyOperation
764 //------------------------------------------------------------------------------
765 
766 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
767     : BaseContextObject(std::move(contextRef)), operation(operation) {}
768 
769 PyOperation::~PyOperation() {
770   // If the operation has already been invalidated there is nothing to do.
771   if (!valid)
772     return;
773   auto &liveOperations = getContext()->liveOperations;
774   assert(liveOperations.count(operation.ptr) == 1 &&
775          "destroying operation not in live map");
776   liveOperations.erase(operation.ptr);
777   if (!isAttached()) {
778     mlirOperationDestroy(operation);
779   }
780 }
781 
782 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
783                                            MlirOperation operation,
784                                            py::object parentKeepAlive) {
785   auto &liveOperations = contextRef->liveOperations;
786   // Create.
787   PyOperation *unownedOperation =
788       new PyOperation(std::move(contextRef), operation);
789   // Note that the default return value policy on cast is automatic_reference,
790   // which does not take ownership (delete will not be called).
791   // Just be explicit.
792   py::object pyRef =
793       py::cast(unownedOperation, py::return_value_policy::take_ownership);
794   unownedOperation->handle = pyRef;
795   if (parentKeepAlive) {
796     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
797   }
798   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
799   return PyOperationRef(unownedOperation, std::move(pyRef));
800 }
801 
802 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
803                                          MlirOperation operation,
804                                          py::object parentKeepAlive) {
805   auto &liveOperations = contextRef->liveOperations;
806   auto it = liveOperations.find(operation.ptr);
807   if (it == liveOperations.end()) {
808     // Create.
809     return createInstance(std::move(contextRef), operation,
810                           std::move(parentKeepAlive));
811   }
812   // Use existing.
813   PyOperation *existing = it->second.second;
814   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
815   return PyOperationRef(existing, std::move(pyRef));
816 }
817 
818 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
819                                            MlirOperation operation,
820                                            py::object parentKeepAlive) {
821   auto &liveOperations = contextRef->liveOperations;
822   assert(liveOperations.count(operation.ptr) == 0 &&
823          "cannot create detached operation that already exists");
824   (void)liveOperations;
825 
826   PyOperationRef created = createInstance(std::move(contextRef), operation,
827                                           std::move(parentKeepAlive));
828   created->attached = false;
829   return created;
830 }
831 
832 void PyOperation::checkValid() const {
833   if (!valid) {
834     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
835   }
836 }
837 
838 void PyOperationBase::print(py::object fileObject, bool binary,
839                             llvm::Optional<int64_t> largeElementsLimit,
840                             bool enableDebugInfo, bool prettyDebugInfo,
841                             bool printGenericOpForm, bool useLocalScope,
842                             bool assumeVerified) {
843   PyOperation &operation = getOperation();
844   operation.checkValid();
845   if (fileObject.is_none())
846     fileObject = py::module::import("sys").attr("stdout");
847 
848   if (!assumeVerified && !printGenericOpForm &&
849       !mlirOperationVerify(operation)) {
850     std::string message("// Verification failed, printing generic form\n");
851     if (binary) {
852       fileObject.attr("write")(py::bytes(message));
853     } else {
854       fileObject.attr("write")(py::str(message));
855     }
856     printGenericOpForm = true;
857   }
858 
859   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
860   if (largeElementsLimit)
861     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
862   if (enableDebugInfo)
863     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
864   if (printGenericOpForm)
865     mlirOpPrintingFlagsPrintGenericOpForm(flags);
866 
867   PyFileAccumulator accum(fileObject, binary);
868   py::gil_scoped_release();
869   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
870                               accum.getUserData());
871   mlirOpPrintingFlagsDestroy(flags);
872 }
873 
874 py::object PyOperationBase::getAsm(bool binary,
875                                    llvm::Optional<int64_t> largeElementsLimit,
876                                    bool enableDebugInfo, bool prettyDebugInfo,
877                                    bool printGenericOpForm, bool useLocalScope,
878                                    bool assumeVerified) {
879   py::object fileObject;
880   if (binary) {
881     fileObject = py::module::import("io").attr("BytesIO")();
882   } else {
883     fileObject = py::module::import("io").attr("StringIO")();
884   }
885   print(fileObject, /*binary=*/binary,
886         /*largeElementsLimit=*/largeElementsLimit,
887         /*enableDebugInfo=*/enableDebugInfo,
888         /*prettyDebugInfo=*/prettyDebugInfo,
889         /*printGenericOpForm=*/printGenericOpForm,
890         /*useLocalScope=*/useLocalScope,
891         /*assumeVerified=*/assumeVerified);
892 
893   return fileObject.attr("getvalue")();
894 }
895 
896 void PyOperationBase::moveAfter(PyOperationBase &other) {
897   PyOperation &operation = getOperation();
898   PyOperation &otherOp = other.getOperation();
899   operation.checkValid();
900   otherOp.checkValid();
901   mlirOperationMoveAfter(operation, otherOp);
902   operation.parentKeepAlive = otherOp.parentKeepAlive;
903 }
904 
905 void PyOperationBase::moveBefore(PyOperationBase &other) {
906   PyOperation &operation = getOperation();
907   PyOperation &otherOp = other.getOperation();
908   operation.checkValid();
909   otherOp.checkValid();
910   mlirOperationMoveBefore(operation, otherOp);
911   operation.parentKeepAlive = otherOp.parentKeepAlive;
912 }
913 
914 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
915   checkValid();
916   if (!isAttached())
917     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
918   MlirOperation operation = mlirOperationGetParentOperation(get());
919   if (mlirOperationIsNull(operation))
920     return {};
921   return PyOperation::forOperation(getContext(), operation);
922 }
923 
924 PyBlock PyOperation::getBlock() {
925   checkValid();
926   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
927   MlirBlock block = mlirOperationGetBlock(get());
928   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
929   assert(parentOperation && "Operation has no parent");
930   return PyBlock{std::move(*parentOperation), block};
931 }
932 
933 py::object PyOperation::getCapsule() {
934   checkValid();
935   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
936 }
937 
938 py::object PyOperation::createFromCapsule(py::object capsule) {
939   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
940   if (mlirOperationIsNull(rawOperation))
941     throw py::error_already_set();
942   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
943   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
944       .releaseObject();
945 }
946 
/// Creates a new operation from its generic components and returns it wrapped
/// in the most specific OpView available (via createOpView()).
///
/// \param name        fully qualified operation name.
/// \param results     optional result types (None entries are rejected).
/// \param operands    optional operand values (None entries are rejected).
/// \param attributes  optional dict of name -> Attribute.
/// \param successors  optional successor blocks (None entries are rejected).
/// \param regions     number of empty regions to create (must be >= 0).
/// \param location    location of the new op; also supplies its context.
/// \param maybeIp     Python `False` keeps the op detached, None uses the
///                    thread's default insertion point (if any), anything
///                    else is cast to an InsertionPoint.
///
/// All Python-side arguments are validated before the MlirOperationState is
/// built. Validation failures raise ValueError (via SetPyError) or
/// py::cast_error.
py::object PyOperation::create(
    std::string name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, py::object maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Attribute names are owned here; MlirNamedAttribute entries built later
  // reference these strings.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception appears to be thrown when the value is "None".
        // It must be caught before the more general py::cast_error below.
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    // Each identifier is created in the attribute's own context.
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Fresh, empty regions whose ownership is handed to the state.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);

  // InsertPoint active? Passing the Python `False` object explicitly opts out
  // of insertion; None means "use the thread's default insertion point".
  if (!maybeIp.is(py::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
      ip = PyThreadContextEntry::getDefaultInsertionPoint();
    } else {
      ip = py::cast<PyInsertionPoint *>(maybeIp);
    }
    // No default insertion point: the operation stays detached.
    if (ip)
      ip->insert(*created.get());
  }

  return created->createOpView();
}
1079 
1080 py::object PyOperation::createOpView() {
1081   checkValid();
1082   MlirIdentifier ident = mlirOperationGetName(get());
1083   MlirStringRef identStr = mlirIdentifierStr(ident);
1084   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1085       StringRef(identStr.data, identStr.length));
1086   if (opViewClass)
1087     return (*opViewClass)(getRef().getObject());
1088   return py::cast(PyOpView(getRef().getObject()));
1089 }
1090 
1091 void PyOperation::erase() {
1092   checkValid();
1093   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1094   // Python reference to a child operation is live. All children should also
1095   // have their `valid` bit set to false.
1096   auto &liveOperations = getContext()->liveOperations;
1097   if (liveOperations.count(operation.ptr))
1098     liveOperations.erase(operation.ptr);
1099   mlirOperationDestroy(operation);
1100   valid = false;
1101 }
1102 
1103 //------------------------------------------------------------------------------
1104 // PyOpView
1105 //------------------------------------------------------------------------------
1106 
1107 py::object
1108 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1109                        py::list operandList,
1110                        llvm::Optional<py::dict> attributes,
1111                        llvm::Optional<std::vector<PyBlock *>> successors,
1112                        llvm::Optional<int> regions,
1113                        DefaultingPyLocation location, py::object maybeIp) {
1114   PyMlirContextRef context = location->getContext();
1115   // Class level operation construction metadata.
1116   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1117   // Operand and result segment specs are either none, which does no
1118   // variadic unpacking, or a list of ints with segment sizes, where each
1119   // element is either a positive number (typically 1 for a scalar) or -1 to
1120   // indicate that it is derived from the length of the same-indexed operand
1121   // or result (implying that it is a list at that position).
1122   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1123   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1124 
1125   std::vector<uint32_t> operandSegmentLengths;
1126   std::vector<uint32_t> resultSegmentLengths;
1127 
1128   // Validate/determine region count.
1129   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1130   int opMinRegionCount = std::get<0>(opRegionSpec);
1131   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1132   if (!regions) {
1133     regions = opMinRegionCount;
1134   }
1135   if (*regions < opMinRegionCount) {
1136     throw py::value_error(
1137         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1138          llvm::Twine(opMinRegionCount) +
1139          " regions but was built with regions=" + llvm::Twine(*regions))
1140             .str());
1141   }
1142   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1143     throw py::value_error(
1144         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1145          llvm::Twine(opMinRegionCount) +
1146          " regions but was built with regions=" + llvm::Twine(*regions))
1147             .str());
1148   }
1149 
1150   // Unpack results.
1151   std::vector<PyType *> resultTypes;
1152   resultTypes.reserve(resultTypeList.size());
1153   if (resultSegmentSpecObj.is_none()) {
1154     // Non-variadic result unpacking.
1155     for (auto it : llvm::enumerate(resultTypeList)) {
1156       try {
1157         resultTypes.push_back(py::cast<PyType *>(it.value()));
1158         if (!resultTypes.back())
1159           throw py::cast_error();
1160       } catch (py::cast_error &err) {
1161         throw py::value_error((llvm::Twine("Result ") +
1162                                llvm::Twine(it.index()) + " of operation \"" +
1163                                name + "\" must be a Type (" + err.what() + ")")
1164                                   .str());
1165       }
1166     }
1167   } else {
1168     // Sized result unpacking.
1169     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1170     if (resultSegmentSpec.size() != resultTypeList.size()) {
1171       throw py::value_error((llvm::Twine("Operation \"") + name +
1172                              "\" requires " +
1173                              llvm::Twine(resultSegmentSpec.size()) +
1174                              " result segments but was provided " +
1175                              llvm::Twine(resultTypeList.size()))
1176                                 .str());
1177     }
1178     resultSegmentLengths.reserve(resultTypeList.size());
1179     for (auto it :
1180          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1181       int segmentSpec = std::get<1>(it.value());
1182       if (segmentSpec == 1 || segmentSpec == 0) {
1183         // Unpack unary element.
1184         try {
1185           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1186           if (resultType) {
1187             resultTypes.push_back(resultType);
1188             resultSegmentLengths.push_back(1);
1189           } else if (segmentSpec == 0) {
1190             // Allowed to be optional.
1191             resultSegmentLengths.push_back(0);
1192           } else {
1193             throw py::cast_error("was None and result is not optional");
1194           }
1195         } catch (py::cast_error &err) {
1196           throw py::value_error((llvm::Twine("Result ") +
1197                                  llvm::Twine(it.index()) + " of operation \"" +
1198                                  name + "\" must be a Type (" + err.what() +
1199                                  ")")
1200                                     .str());
1201         }
1202       } else if (segmentSpec == -1) {
1203         // Unpack sequence by appending.
1204         try {
1205           if (std::get<0>(it.value()).is_none()) {
1206             // Treat it as an empty list.
1207             resultSegmentLengths.push_back(0);
1208           } else {
1209             // Unpack the list.
1210             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1211             for (py::object segmentItem : segment) {
1212               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1213               if (!resultTypes.back()) {
1214                 throw py::cast_error("contained a None item");
1215               }
1216             }
1217             resultSegmentLengths.push_back(segment.size());
1218           }
1219         } catch (std::exception &err) {
1220           // NOTE: Sloppy to be using a catch-all here, but there are at least
1221           // three different unrelated exceptions that can be thrown in the
1222           // above "casts". Just keep the scope above small and catch them all.
1223           throw py::value_error((llvm::Twine("Result ") +
1224                                  llvm::Twine(it.index()) + " of operation \"" +
1225                                  name + "\" must be a Sequence of Types (" +
1226                                  err.what() + ")")
1227                                     .str());
1228         }
1229       } else {
1230         throw py::value_error("Unexpected segment spec");
1231       }
1232     }
1233   }
1234 
1235   // Unpack operands.
1236   std::vector<PyValue *> operands;
1237   operands.reserve(operands.size());
1238   if (operandSegmentSpecObj.is_none()) {
1239     // Non-sized operand unpacking.
1240     for (auto it : llvm::enumerate(operandList)) {
1241       try {
1242         operands.push_back(py::cast<PyValue *>(it.value()));
1243         if (!operands.back())
1244           throw py::cast_error();
1245       } catch (py::cast_error &err) {
1246         throw py::value_error((llvm::Twine("Operand ") +
1247                                llvm::Twine(it.index()) + " of operation \"" +
1248                                name + "\" must be a Value (" + err.what() + ")")
1249                                   .str());
1250       }
1251     }
1252   } else {
1253     // Sized operand unpacking.
1254     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1255     if (operandSegmentSpec.size() != operandList.size()) {
1256       throw py::value_error((llvm::Twine("Operation \"") + name +
1257                              "\" requires " +
1258                              llvm::Twine(operandSegmentSpec.size()) +
1259                              "operand segments but was provided " +
1260                              llvm::Twine(operandList.size()))
1261                                 .str());
1262     }
1263     operandSegmentLengths.reserve(operandList.size());
1264     for (auto it :
1265          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1266       int segmentSpec = std::get<1>(it.value());
1267       if (segmentSpec == 1 || segmentSpec == 0) {
1268         // Unpack unary element.
1269         try {
1270           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1271           if (operandValue) {
1272             operands.push_back(operandValue);
1273             operandSegmentLengths.push_back(1);
1274           } else if (segmentSpec == 0) {
1275             // Allowed to be optional.
1276             operandSegmentLengths.push_back(0);
1277           } else {
1278             throw py::cast_error("was None and operand is not optional");
1279           }
1280         } catch (py::cast_error &err) {
1281           throw py::value_error((llvm::Twine("Operand ") +
1282                                  llvm::Twine(it.index()) + " of operation \"" +
1283                                  name + "\" must be a Value (" + err.what() +
1284                                  ")")
1285                                     .str());
1286         }
1287       } else if (segmentSpec == -1) {
1288         // Unpack sequence by appending.
1289         try {
1290           if (std::get<0>(it.value()).is_none()) {
1291             // Treat it as an empty list.
1292             operandSegmentLengths.push_back(0);
1293           } else {
1294             // Unpack the list.
1295             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1296             for (py::object segmentItem : segment) {
1297               operands.push_back(py::cast<PyValue *>(segmentItem));
1298               if (!operands.back()) {
1299                 throw py::cast_error("contained a None item");
1300               }
1301             }
1302             operandSegmentLengths.push_back(segment.size());
1303           }
1304         } catch (std::exception &err) {
1305           // NOTE: Sloppy to be using a catch-all here, but there are at least
1306           // three different unrelated exceptions that can be thrown in the
1307           // above "casts". Just keep the scope above small and catch them all.
1308           throw py::value_error((llvm::Twine("Operand ") +
1309                                  llvm::Twine(it.index()) + " of operation \"" +
1310                                  name + "\" must be a Sequence of Values (" +
1311                                  err.what() + ")")
1312                                     .str());
1313         }
1314       } else {
1315         throw py::value_error("Unexpected segment spec");
1316       }
1317     }
1318   }
1319 
1320   // Merge operand/result segment lengths into attributes if needed.
1321   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1322     // Dup.
1323     if (attributes) {
1324       attributes = py::dict(*attributes);
1325     } else {
1326       attributes = py::dict();
1327     }
1328     if (attributes->contains("result_segment_sizes") ||
1329         attributes->contains("operand_segment_sizes")) {
1330       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1331                             "'operand_segment_sizes' attribute is unsupported. "
1332                             "Use Operation.create for such low-level access.");
1333     }
1334 
1335     // Add result_segment_sizes attribute.
1336     if (!resultSegmentLengths.empty()) {
1337       int64_t size = resultSegmentLengths.size();
1338       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1339           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1340           resultSegmentLengths.size(), resultSegmentLengths.data());
1341       (*attributes)["result_segment_sizes"] =
1342           PyAttribute(context, segmentLengthAttr);
1343     }
1344 
1345     // Add operand_segment_sizes attribute.
1346     if (!operandSegmentLengths.empty()) {
1347       int64_t size = operandSegmentLengths.size();
1348       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1349           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1350           operandSegmentLengths.size(), operandSegmentLengths.data());
1351       (*attributes)["operand_segment_sizes"] =
1352           PyAttribute(context, segmentLengthAttr);
1353     }
1354   }
1355 
1356   // Delegate to create.
1357   return PyOperation::create(std::move(name),
1358                              /*results=*/std::move(resultTypes),
1359                              /*operands=*/std::move(operands),
1360                              /*attributes=*/std::move(attributes),
1361                              /*successors=*/std::move(successors),
1362                              /*regions=*/*regions, location, maybeIp);
1363 }
1364 
/// Wraps an existing operation object (any PyOperationBase subclass, e.g. an
/// Operation or another OpView) in a generic OpView.
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    // The stored `operationObject` member is the underlying PyOperation's own
    // Python object, not necessarily the object passed in.
    // NOTE(review): this relies on the `operation` member being declared
    // before `operationObject` in the class — TODO confirm in the header.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1370 
1371 py::object PyOpView::createRawSubclass(py::object userClass) {
1372   // This is... a little gross. The typical pattern is to have a pure python
1373   // class that extends OpView like:
1374   //   class AddFOp(_cext.ir.OpView):
1375   //     def __init__(self, loc, lhs, rhs):
1376   //       operation = loc.context.create_operation(
1377   //           "addf", lhs, rhs, results=[lhs.type])
1378   //       super().__init__(operation)
1379   //
1380   // I.e. The goal of the user facing type is to provide a nice constructor
1381   // that has complete freedom for the op under construction. This is at odds
1382   // with our other desire to sometimes create this object by just passing an
1383   // operation (to initialize the base class). We could do *arg and **kwargs
1384   // munging to try to make it work, but instead, we synthesize a new class
1385   // on the fly which extends this user class (AddFOp in this example) and
1386   // *give it* the base class's __init__ method, thus bypassing the
1387   // intermediate subclass's __init__ method entirely. While slightly,
1388   // underhanded, this is safe/legal because the type hierarchy has not changed
1389   // (we just added a new leaf) and we aren't mucking around with __new__.
1390   // Typically, this new class will be stored on the original as "_Raw" and will
1391   // be used for casts and other things that need a variant of the class that
1392   // is initialized purely from an operation.
1393   py::object parentMetaclass =
1394       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1395   py::dict attributes;
1396   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1397   // now.
1398   //   auto opViewType = py::type::of<PyOpView>();
1399   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1400   attributes["__init__"] = opViewType.attr("__init__");
1401   py::str origName = userClass.attr("__name__");
1402   py::str newName = py::str("_") + origName;
1403   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1404 }
1405 
1406 //------------------------------------------------------------------------------
1407 // PyInsertionPoint.
1408 //------------------------------------------------------------------------------
1409 
/// Insertion point that appends to the end of `block` (no reference op).
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1411 
/// Insertion point that inserts immediately before `beforeOperationBase`, in
/// the block that contains it. getBlock() raises ValueError if that operation
/// is detached.
// NOTE(review): `block` is initialized from `refOperation`, so this relies on
// `refOperation` being declared before `block` in the class — TODO confirm.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1415 
1416 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1417   PyOperation &operation = operationBase.getOperation();
1418   if (operation.isAttached())
1419     throw SetPyError(PyExc_ValueError,
1420                      "Attempt to insert operation that is already attached");
1421   block.getParentOperation()->checkValid();
1422   MlirOperation beforeOp = {nullptr};
1423   if (refOperation) {
1424     // Insert before operation.
1425     (*refOperation)->checkValid();
1426     beforeOp = (*refOperation)->get();
1427   } else {
1428     // Insert at end (before null) is only valid if the block does not
1429     // already end in a known terminator (violating this will cause assertion
1430     // failures later).
1431     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1432       throw py::index_error("Cannot insert operation at the end of a block "
1433                             "that already has a terminator. Did you mean to "
1434                             "use 'InsertionPoint.at_block_terminator(block)' "
1435                             "versus 'InsertionPoint(block)'?");
1436     }
1437   }
1438   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1439   operation.setAttached();
1440 }
1441 
1442 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1443   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1444   if (mlirOperationIsNull(firstOp)) {
1445     // Just insert at end.
1446     return PyInsertionPoint(block);
1447   }
1448 
1449   // Insert before first op.
1450   PyOperationRef firstOpRef = PyOperation::forOperation(
1451       block.getParentOperation()->getContext(), firstOp);
1452   return PyInsertionPoint{block, std::move(firstOpRef)};
1453 }
1454 
1455 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1456   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1457   if (mlirOperationIsNull(terminator))
1458     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1459   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1460       block.getParentOperation()->getContext(), terminator);
1461   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1462 }
1463 
/// Python `__enter__`: pushes this insertion point onto the thread's context
/// stack and returns whatever pushInsertionPoint returns.
py::object PyInsertionPoint::contextEnter() {
  return PyThreadContextEntry::pushInsertionPoint(*this);
}
1467 
/// Python `__exit__`: pops this insertion point from the thread's context
/// stack. The exception arguments are ignored, so exceptions are never
/// suppressed here.
void PyInsertionPoint::contextExit(pybind11::object excType,
                                   pybind11::object excVal,
                                   pybind11::object excTb) {
  PyThreadContextEntry::popInsertionPoint(*this);
}
1473 
1474 //------------------------------------------------------------------------------
1475 // PyAttribute.
1476 //------------------------------------------------------------------------------
1477 
1478 bool PyAttribute::operator==(const PyAttribute &other) {
1479   return mlirAttributeEqual(attr, other.attr);
1480 }
1481 
1482 py::object PyAttribute::getCapsule() {
1483   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1484 }
1485 
1486 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1487   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1488   if (mlirAttributeIsNull(rawAttr))
1489     throw py::error_already_set();
1490   return PyAttribute(
1491       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1492 }
1493 
1494 //------------------------------------------------------------------------------
1495 // PyNamedAttribute.
1496 //------------------------------------------------------------------------------
1497 
1498 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1499     : ownedName(new std::string(std::move(ownedName))) {
1500   namedAttr = mlirNamedAttributeGet(
1501       mlirIdentifierGet(mlirAttributeGetContext(attr),
1502                         toMlirStringRef(*this->ownedName)),
1503       attr);
1504 }
1505 
1506 //------------------------------------------------------------------------------
1507 // PyType.
1508 //------------------------------------------------------------------------------
1509 
1510 bool PyType::operator==(const PyType &other) {
1511   return mlirTypeEqual(type, other.type);
1512 }
1513 
1514 py::object PyType::getCapsule() {
1515   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1516 }
1517 
1518 PyType PyType::createFromCapsule(py::object capsule) {
1519   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1520   if (mlirTypeIsNull(rawType))
1521     throw py::error_already_set();
1522   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1523                 rawType);
1524 }
1525 
1526 //------------------------------------------------------------------------------
1527 // PyValue and subclases.
1528 //------------------------------------------------------------------------------
1529 
1530 pybind11::object PyValue::getCapsule() {
1531   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1532 }
1533 
1534 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1535   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1536   if (mlirValueIsNull(value))
1537     throw py::error_already_set();
1538   MlirOperation owner;
1539   if (mlirValueIsAOpResult(value))
1540     owner = mlirOpResultGetOwner(value);
1541   if (mlirValueIsABlockArgument(value))
1542     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1543   if (mlirOperationIsNull(owner))
1544     throw py::error_already_set();
1545   MlirContext ctx = mlirOperationGetContext(owner);
1546   PyOperationRef ownerRef =
1547       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1548   return PyValue(ownerRef, value);
1549 }
1550 
1551 //------------------------------------------------------------------------------
1552 // PySymbolTable.
1553 //------------------------------------------------------------------------------
1554 
1555 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1556     : operation(operation.getOperation().getRef()) {
1557   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1558   if (mlirSymbolTableIsNull(symbolTable)) {
1559     throw py::cast_error("Operation is not a Symbol Table.");
1560   }
1561 }
1562 
1563 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1564   operation->checkValid();
1565   MlirOperation symbol = mlirSymbolTableLookup(
1566       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1567   if (mlirOperationIsNull(symbol))
1568     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1569 
1570   return PyOperation::forOperation(operation->getContext(), symbol,
1571                                    operation.getObject())
1572       ->createOpView();
1573 }
1574 
1575 void PySymbolTable::erase(PyOperationBase &symbol) {
1576   operation->checkValid();
1577   symbol.getOperation().checkValid();
1578   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1579   // The operation is also erased, so we must invalidate it. There may be Python
1580   // references to this operation so we don't want to delete it from the list of
1581   // live operations here.
1582   symbol.getOperation().valid = false;
1583 }
1584 
1585 void PySymbolTable::dunderDel(const std::string &name) {
1586   py::object operation = dunderGetItem(name);
1587   erase(py::cast<PyOperationBase &>(operation));
1588 }
1589 
1590 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1591   operation->checkValid();
1592   symbol.getOperation().checkValid();
1593   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1594       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1595   if (mlirAttributeIsNull(symbolAttr))
1596     throw py::value_error("Expected operation to have a symbol name.");
1597   return PyAttribute(
1598       symbol.getOperation().getContext(),
1599       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1600 }
1601 
1602 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1603   // Op must already be a symbol.
1604   PyOperation &operation = symbol.getOperation();
1605   operation.checkValid();
1606   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1607   MlirAttribute existingNameAttr =
1608       mlirOperationGetAttributeByName(operation.get(), attrName);
1609   if (mlirAttributeIsNull(existingNameAttr))
1610     throw py::value_error("Expected operation to have a symbol name.");
1611   return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1612 }
1613 
1614 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1615                                   const std::string &name) {
1616   // Op must already be a symbol.
1617   PyOperation &operation = symbol.getOperation();
1618   operation.checkValid();
1619   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1620   MlirAttribute existingNameAttr =
1621       mlirOperationGetAttributeByName(operation.get(), attrName);
1622   if (mlirAttributeIsNull(existingNameAttr))
1623     throw py::value_error("Expected operation to have a symbol name.");
1624   MlirAttribute newNameAttr =
1625       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1626   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1627 }
1628 
1629 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1630   PyOperation &operation = symbol.getOperation();
1631   operation.checkValid();
1632   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1633   MlirAttribute existingVisAttr =
1634       mlirOperationGetAttributeByName(operation.get(), attrName);
1635   if (mlirAttributeIsNull(existingVisAttr))
1636     throw py::value_error("Expected operation to have a symbol visibility.");
1637   return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1638 }
1639 
1640 void PySymbolTable::setVisibility(PyOperationBase &symbol,
1641                                   const std::string &visibility) {
1642   if (visibility != "public" && visibility != "private" &&
1643       visibility != "nested")
1644     throw py::value_error(
1645         "Expected visibility to be 'public', 'private' or 'nested'");
1646   PyOperation &operation = symbol.getOperation();
1647   operation.checkValid();
1648   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1649   MlirAttribute existingVisAttr =
1650       mlirOperationGetAttributeByName(operation.get(), attrName);
1651   if (mlirAttributeIsNull(existingVisAttr))
1652     throw py::value_error("Expected operation to have a symbol visibility.");
1653   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1654                                                toMlirStringRef(visibility));
1655   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1656 }
1657 
1658 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1659                                          const std::string &newSymbol,
1660                                          PyOperationBase &from) {
1661   PyOperation &fromOperation = from.getOperation();
1662   fromOperation.checkValid();
1663   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1664           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1665           from.getOperation())))
1666 
1667     throw py::value_error("Symbol rename failed");
1668 }
1669 
1670 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1671                                      bool allSymUsesVisible,
1672                                      py::object callback) {
1673   PyOperation &fromOperation = from.getOperation();
1674   fromOperation.checkValid();
1675   struct UserData {
1676     PyMlirContextRef context;
1677     py::object callback;
1678     bool gotException;
1679     std::string exceptionWhat;
1680     py::object exceptionType;
1681   };
1682   UserData userData{
1683       fromOperation.getContext(), std::move(callback), false, {}, {}};
1684   mlirSymbolTableWalkSymbolTables(
1685       fromOperation.get(), allSymUsesVisible,
1686       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1687         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1688         auto pyFoundOp =
1689             PyOperation::forOperation(calleeUserData->context, foundOp);
1690         if (calleeUserData->gotException)
1691           return;
1692         try {
1693           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1694         } catch (py::error_already_set &e) {
1695           calleeUserData->gotException = true;
1696           calleeUserData->exceptionWhat = e.what();
1697           calleeUserData->exceptionType = e.type();
1698         }
1699       },
1700       static_cast<void *>(&userData));
1701   if (userData.gotException) {
1702     std::string message("Exception raised in callback: ");
1703     message.append(userData.exceptionWhat);
1704     throw std::runtime_error(std::move(message));
1705   }
1706 }
1707 
1708 namespace {
1709 /// CRTP base class for Python MLIR values that subclass Value and should be
1710 /// castable from it. The value hierarchy is one level deep and is not supposed
1711 /// to accommodate other levels unless core MLIR changes.
1712 template <typename DerivedTy>
1713 class PyConcreteValue : public PyValue {
1714 public:
1715   // Derived classes must define statics for:
1716   //   IsAFunctionTy isaFunction
1717   //   const char *pyClassName
1718   // and redefine bindDerived.
1719   using ClassTy = py::class_<DerivedTy, PyValue>;
1720   using IsAFunctionTy = bool (*)(MlirValue);
1721 
1722   PyConcreteValue() = default;
1723   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1724       : PyValue(operationRef, value) {}
1725   PyConcreteValue(PyValue &orig)
1726       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1727 
1728   /// Attempts to cast the original value to the derived type and throws on
1729   /// type mismatches.
1730   static MlirValue castFrom(PyValue &orig) {
1731     if (!DerivedTy::isaFunction(orig.get())) {
1732       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1733       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1734                                              DerivedTy::pyClassName +
1735                                              " (from " + origRepr + ")");
1736     }
1737     return orig.get();
1738   }
1739 
1740   /// Binds the Python module objects to functions of this class.
1741   static void bind(py::module &m) {
1742     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1743     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1744     cls.def_static(
1745         "isinstance",
1746         [](PyValue &otherValue) -> bool {
1747           return DerivedTy::isaFunction(otherValue);
1748         },
1749         py::arg("other_value"));
1750     DerivedTy::bindDerived(cls);
1751   }
1752 
1753   /// Implemented by derived classes to add methods to the Python subclass.
1754   static void bindDerived(ClassTy &m) {}
1755 };
1756 
1757 /// Python wrapper for MlirBlockArgument.
1758 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1759 public:
1760   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1761   static constexpr const char *pyClassName = "BlockArgument";
1762   using PyConcreteValue::PyConcreteValue;
1763 
1764   static void bindDerived(ClassTy &c) {
1765     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1766       return PyBlock(self.getParentOperation(),
1767                      mlirBlockArgumentGetOwner(self.get()));
1768     });
1769     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1770       return mlirBlockArgumentGetArgNumber(self.get());
1771     });
1772     c.def(
1773         "set_type",
1774         [](PyBlockArgument &self, PyType type) {
1775           return mlirBlockArgumentSetType(self.get(), type);
1776         },
1777         py::arg("type"));
1778   }
1779 };
1780 
1781 /// Python wrapper for MlirOpResult.
1782 class PyOpResult : public PyConcreteValue<PyOpResult> {
1783 public:
1784   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1785   static constexpr const char *pyClassName = "OpResult";
1786   using PyConcreteValue::PyConcreteValue;
1787 
1788   static void bindDerived(ClassTy &c) {
1789     c.def_property_readonly("owner", [](PyOpResult &self) {
1790       assert(
1791           mlirOperationEqual(self.getParentOperation()->get(),
1792                              mlirOpResultGetOwner(self.get())) &&
1793           "expected the owner of the value in Python to match that in the IR");
1794       return self.getParentOperation().getObject();
1795     });
1796     c.def_property_readonly("result_number", [](PyOpResult &self) {
1797       return mlirOpResultGetResultNumber(self.get());
1798     });
1799   }
1800 };
1801 
1802 /// Returns the list of types of the values held by container.
1803 template <typename Container>
1804 static std::vector<PyType> getValueTypes(Container &container,
1805                                          PyMlirContextRef &context) {
1806   std::vector<PyType> result;
1807   result.reserve(container.getNumElements());
1808   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1809     result.push_back(
1810         PyType(context, mlirValueGetType(container.getElement(i).get())));
1811   }
1812   return result;
1813 }
1814 
1815 /// A list of block arguments. Internally, these are stored as consecutive
1816 /// elements, random access is cheap. The argument list is associated with the
1817 /// operation that contains the block (detached blocks are not allowed in
1818 /// Python bindings) and extends its lifetime.
1819 class PyBlockArgumentList
1820     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1821 public:
1822   static constexpr const char *pyClassName = "BlockArgumentList";
1823 
1824   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1825                       intptr_t startIndex = 0, intptr_t length = -1,
1826                       intptr_t step = 1)
1827       : Sliceable(startIndex,
1828                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1829                   step),
1830         operation(std::move(operation)), block(block) {}
1831 
1832   /// Returns the number of arguments in the list.
1833   intptr_t getNumElements() {
1834     operation->checkValid();
1835     return mlirBlockGetNumArguments(block);
1836   }
1837 
1838   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1839   PyBlockArgument getElement(intptr_t pos) {
1840     MlirValue argument = mlirBlockGetArgument(block, pos);
1841     return PyBlockArgument(operation, argument);
1842   }
1843 
1844   /// Returns a sublist of this list.
1845   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1846                             intptr_t step) {
1847     return PyBlockArgumentList(operation, block, startIndex, length, step);
1848   }
1849 
1850   static void bindDerived(ClassTy &c) {
1851     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1852       return getValueTypes(self, self.operation->getContext());
1853     });
1854   }
1855 
1856 private:
1857   PyOperationRef operation;
1858   MlirBlock block;
1859 };
1860 
1861 /// A list of operation operands. Internally, these are stored as consecutive
1862 /// elements, random access is cheap. The result list is associated with the
1863 /// operation whose results these are, and extends the lifetime of this
1864 /// operation.
1865 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1866 public:
1867   static constexpr const char *pyClassName = "OpOperandList";
1868 
1869   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1870                   intptr_t length = -1, intptr_t step = 1)
1871       : Sliceable(startIndex,
1872                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1873                                : length,
1874                   step),
1875         operation(operation) {}
1876 
1877   intptr_t getNumElements() {
1878     operation->checkValid();
1879     return mlirOperationGetNumOperands(operation->get());
1880   }
1881 
1882   PyValue getElement(intptr_t pos) {
1883     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1884     MlirOperation owner;
1885     if (mlirValueIsAOpResult(operand))
1886       owner = mlirOpResultGetOwner(operand);
1887     else if (mlirValueIsABlockArgument(operand))
1888       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1889     else
1890       assert(false && "Value must be an block arg or op result.");
1891     PyOperationRef pyOwner =
1892         PyOperation::forOperation(operation->getContext(), owner);
1893     return PyValue(pyOwner, operand);
1894   }
1895 
1896   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1897     return PyOpOperandList(operation, startIndex, length, step);
1898   }
1899 
1900   void dunderSetItem(intptr_t index, PyValue value) {
1901     index = wrapIndex(index);
1902     mlirOperationSetOperand(operation->get(), index, value.get());
1903   }
1904 
1905   static void bindDerived(ClassTy &c) {
1906     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1907   }
1908 
1909 private:
1910   PyOperationRef operation;
1911 };
1912 
1913 /// A list of operation results. Internally, these are stored as consecutive
1914 /// elements, random access is cheap. The result list is associated with the
1915 /// operation whose results these are, and extends the lifetime of this
1916 /// operation.
1917 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1918 public:
1919   static constexpr const char *pyClassName = "OpResultList";
1920 
1921   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1922                  intptr_t length = -1, intptr_t step = 1)
1923       : Sliceable(startIndex,
1924                   length == -1 ? mlirOperationGetNumResults(operation->get())
1925                                : length,
1926                   step),
1927         operation(operation) {}
1928 
1929   intptr_t getNumElements() {
1930     operation->checkValid();
1931     return mlirOperationGetNumResults(operation->get());
1932   }
1933 
1934   PyOpResult getElement(intptr_t index) {
1935     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1936     return PyOpResult(value);
1937   }
1938 
1939   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1940     return PyOpResultList(operation, startIndex, length, step);
1941   }
1942 
1943   static void bindDerived(ClassTy &c) {
1944     c.def_property_readonly("types", [](PyOpResultList &self) {
1945       return getValueTypes(self, self.operation->getContext());
1946     });
1947   }
1948 
1949 private:
1950   PyOperationRef operation;
1951 };
1952 
1953 /// A list of operation attributes. Can be indexed by name, producing
1954 /// attributes, or by index, producing named attributes.
1955 class PyOpAttributeMap {
1956 public:
1957   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1958 
1959   PyAttribute dunderGetItemNamed(const std::string &name) {
1960     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1961                                                          toMlirStringRef(name));
1962     if (mlirAttributeIsNull(attr)) {
1963       throw SetPyError(PyExc_KeyError,
1964                        "attempt to access a non-existent attribute");
1965     }
1966     return PyAttribute(operation->getContext(), attr);
1967   }
1968 
1969   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1970     if (index < 0 || index >= dunderLen()) {
1971       throw SetPyError(PyExc_IndexError,
1972                        "attempt to access out of bounds attribute");
1973     }
1974     MlirNamedAttribute namedAttr =
1975         mlirOperationGetAttribute(operation->get(), index);
1976     return PyNamedAttribute(
1977         namedAttr.attribute,
1978         std::string(mlirIdentifierStr(namedAttr.name).data,
1979                     mlirIdentifierStr(namedAttr.name).length));
1980   }
1981 
1982   void dunderSetItem(const std::string &name, PyAttribute attr) {
1983     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1984                                     attr);
1985   }
1986 
1987   void dunderDelItem(const std::string &name) {
1988     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1989                                                      toMlirStringRef(name));
1990     if (!removed)
1991       throw SetPyError(PyExc_KeyError,
1992                        "attempt to delete a non-existent attribute");
1993   }
1994 
1995   intptr_t dunderLen() {
1996     return mlirOperationGetNumAttributes(operation->get());
1997   }
1998 
1999   bool dunderContains(const std::string &name) {
2000     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2001         operation->get(), toMlirStringRef(name)));
2002   }
2003 
2004   static void bind(py::module &m) {
2005     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2006         .def("__contains__", &PyOpAttributeMap::dunderContains)
2007         .def("__len__", &PyOpAttributeMap::dunderLen)
2008         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2009         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2010         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2011         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2012   }
2013 
2014 private:
2015   PyOperationRef operation;
2016 };
2017 
2018 } // namespace
2019 
2020 //------------------------------------------------------------------------------
2021 // Populates the core exports of the 'ir' submodule.
2022 //------------------------------------------------------------------------------
2023 
2024 void mlir::python::populateIRCore(py::module &m) {
2025   //----------------------------------------------------------------------------
2026   // Mapping of MlirContext.
2027   //----------------------------------------------------------------------------
2028   py::class_<PyMlirContext>(m, "Context", py::module_local())
2029       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2030       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2031       .def("_get_context_again",
2032            [](PyMlirContext &self) {
2033              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2034              return ref.releaseObject();
2035            })
2036       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2037       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2038       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2039                              &PyMlirContext::getCapsule)
2040       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2041       .def("__enter__", &PyMlirContext::contextEnter)
2042       .def("__exit__", &PyMlirContext::contextExit)
2043       .def_property_readonly_static(
2044           "current",
2045           [](py::object & /*class*/) {
2046             auto *context = PyThreadContextEntry::getDefaultContext();
2047             if (!context)
2048               throw SetPyError(PyExc_ValueError, "No current Context");
2049             return context;
2050           },
2051           "Gets the Context bound to the current thread or raises ValueError")
2052       .def_property_readonly(
2053           "dialects",
2054           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2055           "Gets a container for accessing dialects by name")
2056       .def_property_readonly(
2057           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2058           "Alias for 'dialect'")
2059       .def(
2060           "get_dialect_descriptor",
2061           [=](PyMlirContext &self, std::string &name) {
2062             MlirDialect dialect = mlirContextGetOrLoadDialect(
2063                 self.get(), {name.data(), name.size()});
2064             if (mlirDialectIsNull(dialect)) {
2065               throw SetPyError(PyExc_ValueError,
2066                                Twine("Dialect '") + name + "' not found");
2067             }
2068             return PyDialectDescriptor(self.getRef(), dialect);
2069           },
2070           py::arg("dialect_name"),
2071           "Gets or loads a dialect by name, returning its descriptor object")
2072       .def_property(
2073           "allow_unregistered_dialects",
2074           [](PyMlirContext &self) -> bool {
2075             return mlirContextGetAllowUnregisteredDialects(self.get());
2076           },
2077           [](PyMlirContext &self, bool value) {
2078             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2079           })
2080       .def(
2081           "enable_multithreading",
2082           [](PyMlirContext &self, bool enable) {
2083             mlirContextEnableMultithreading(self.get(), enable);
2084           },
2085           py::arg("enable"))
2086       .def(
2087           "is_registered_operation",
2088           [](PyMlirContext &self, std::string &name) {
2089             return mlirContextIsRegisteredOperation(
2090                 self.get(), MlirStringRef{name.data(), name.size()});
2091           },
2092           py::arg("operation_name"));
2093 
2094   //----------------------------------------------------------------------------
2095   // Mapping of PyDialectDescriptor
2096   //----------------------------------------------------------------------------
2097   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2098       .def_property_readonly("namespace",
2099                              [](PyDialectDescriptor &self) {
2100                                MlirStringRef ns =
2101                                    mlirDialectGetNamespace(self.get());
2102                                return py::str(ns.data, ns.length);
2103                              })
2104       .def("__repr__", [](PyDialectDescriptor &self) {
2105         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2106         std::string repr("<DialectDescriptor ");
2107         repr.append(ns.data, ns.length);
2108         repr.append(">");
2109         return repr;
2110       });
2111 
2112   //----------------------------------------------------------------------------
2113   // Mapping of PyDialects
2114   //----------------------------------------------------------------------------
2115   py::class_<PyDialects>(m, "Dialects", py::module_local())
2116       .def("__getitem__",
2117            [=](PyDialects &self, std::string keyName) {
2118              MlirDialect dialect =
2119                  self.getDialectForKey(keyName, /*attrError=*/false);
2120              py::object descriptor =
2121                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2122              return createCustomDialectWrapper(keyName, std::move(descriptor));
2123            })
2124       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2125         MlirDialect dialect =
2126             self.getDialectForKey(attrName, /*attrError=*/true);
2127         py::object descriptor =
2128             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2129         return createCustomDialectWrapper(attrName, std::move(descriptor));
2130       });
2131 
2132   //----------------------------------------------------------------------------
2133   // Mapping of PyDialect
2134   //----------------------------------------------------------------------------
2135   py::class_<PyDialect>(m, "Dialect", py::module_local())
2136       .def(py::init<py::object>(), py::arg("descriptor"))
2137       .def_property_readonly(
2138           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2139       .def("__repr__", [](py::object self) {
2140         auto clazz = self.attr("__class__");
2141         return py::str("<Dialect ") +
2142                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2143                clazz.attr("__module__") + py::str(".") +
2144                clazz.attr("__name__") + py::str(")>");
2145       });
2146 
2147   //----------------------------------------------------------------------------
2148   // Mapping of Location
2149   //----------------------------------------------------------------------------
2150   py::class_<PyLocation>(m, "Location", py::module_local())
2151       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2152       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2153       .def("__enter__", &PyLocation::contextEnter)
2154       .def("__exit__", &PyLocation::contextExit)
2155       .def("__eq__",
2156            [](PyLocation &self, PyLocation &other) -> bool {
2157              return mlirLocationEqual(self, other);
2158            })
2159       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2160       .def_property_readonly_static(
2161           "current",
2162           [](py::object & /*class*/) {
2163             auto *loc = PyThreadContextEntry::getDefaultLocation();
2164             if (!loc)
2165               throw SetPyError(PyExc_ValueError, "No current Location");
2166             return loc;
2167           },
2168           "Gets the Location bound to the current thread or raises ValueError")
2169       .def_static(
2170           "unknown",
2171           [](DefaultingPyMlirContext context) {
2172             return PyLocation(context->getRef(),
2173                               mlirLocationUnknownGet(context->get()));
2174           },
2175           py::arg("context") = py::none(),
2176           "Gets a Location representing an unknown location")
2177       .def_static(
2178           "callsite",
2179           [](PyLocation callee, const std::vector<PyLocation> &frames,
2180              DefaultingPyMlirContext context) {
2181             if (frames.empty())
2182               throw py::value_error("No caller frames provided");
2183             MlirLocation caller = frames.back().get();
2184             for (const PyLocation &frame :
2185                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2186               caller = mlirLocationCallSiteGet(frame.get(), caller);
2187             return PyLocation(context->getRef(),
2188                               mlirLocationCallSiteGet(callee.get(), caller));
2189           },
2190           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2191           kContextGetCallSiteLocationDocstring)
2192       .def_static(
2193           "file",
2194           [](std::string filename, int line, int col,
2195              DefaultingPyMlirContext context) {
2196             return PyLocation(
2197                 context->getRef(),
2198                 mlirLocationFileLineColGet(
2199                     context->get(), toMlirStringRef(filename), line, col));
2200           },
2201           py::arg("filename"), py::arg("line"), py::arg("col"),
2202           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2203       .def_static(
2204           "fused",
2205           [](const std::vector<PyLocation> &pyLocations, llvm::Optional<PyAttribute> metadata,
2206              DefaultingPyMlirContext context) {
2207             if (pyLocations.empty())
2208               throw py::value_error("No locations provided");
2209             llvm::SmallVector<MlirLocation, 4> locations;
2210             locations.reserve(pyLocations.size());
2211             for (auto &pyLocation : pyLocations)
2212               locations.push_back(pyLocation.get());
2213             MlirLocation location = mlirLocationFusedGet(
2214                 context->get(), locations.size(), locations.data(),
2215                 metadata ? metadata->get() : MlirAttribute{0});
2216             return PyLocation(context->getRef(), location);
2217           },
2218           py::arg("locations"), py::arg("metadata") = py::none(),
2219           py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2220       .def_static(
2221           "name",
2222           [](std::string name, llvm::Optional<PyLocation> childLoc,
2223              DefaultingPyMlirContext context) {
2224             return PyLocation(
2225                 context->getRef(),
2226                 mlirLocationNameGet(
2227                     context->get(), toMlirStringRef(name),
2228                     childLoc ? childLoc->get()
2229                              : mlirLocationUnknownGet(context->get())));
2230           },
2231           py::arg("name"), py::arg("childLoc") = py::none(),
2232           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2233       .def_property_readonly(
2234           "context",
2235           [](PyLocation &self) { return self.getContext().getObject(); },
2236           "Context that owns the Location")
2237       .def("__repr__", [](PyLocation &self) {
2238         PyPrintAccumulator printAccum;
2239         mlirLocationPrint(self, printAccum.getCallback(),
2240                           printAccum.getUserData());
2241         return printAccum.join();
2242       });
2243 
2244   //----------------------------------------------------------------------------
2245   // Mapping of Module
2246   //----------------------------------------------------------------------------
2247   py::class_<PyModule>(m, "Module", py::module_local())
2248       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2249       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2250       .def_static(
2251           "parse",
2252           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2253             MlirModule module = mlirModuleCreateParse(
2254                 context->get(), toMlirStringRef(moduleAsm));
2255             // TODO: Rework error reporting once diagnostic engine is exposed
2256             // in C API.
2257             if (mlirModuleIsNull(module)) {
2258               throw SetPyError(
2259                   PyExc_ValueError,
2260                   "Unable to parse module assembly (see diagnostics)");
2261             }
2262             return PyModule::forModule(module).releaseObject();
2263           },
2264           py::arg("asm"), py::arg("context") = py::none(),
2265           kModuleParseDocstring)
2266       .def_static(
2267           "create",
2268           [](DefaultingPyLocation loc) {
2269             MlirModule module = mlirModuleCreateEmpty(loc);
2270             return PyModule::forModule(module).releaseObject();
2271           },
2272           py::arg("loc") = py::none(), "Creates an empty module")
2273       .def_property_readonly(
2274           "context",
2275           [](PyModule &self) { return self.getContext().getObject(); },
2276           "Context that created the Module")
2277       .def_property_readonly(
2278           "operation",
2279           [](PyModule &self) {
2280             return PyOperation::forOperation(self.getContext(),
2281                                              mlirModuleGetOperation(self.get()),
2282                                              self.getRef().releaseObject())
2283                 .releaseObject();
2284           },
2285           "Accesses the module as an operation")
2286       .def_property_readonly(
2287           "body",
2288           [](PyModule &self) {
2289             PyOperationRef module_op = PyOperation::forOperation(
2290                 self.getContext(), mlirModuleGetOperation(self.get()),
2291                 self.getRef().releaseObject());
2292             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2293             return returnBlock;
2294           },
2295           "Return the block for this module")
2296       .def(
2297           "dump",
2298           [](PyModule &self) {
2299             mlirOperationDump(mlirModuleGetOperation(self.get()));
2300           },
2301           kDumpDocstring)
2302       .def(
2303           "__str__",
2304           [](py::object self) {
2305             // Defer to the operation's __str__.
2306             return self.attr("operation").attr("__str__")();
2307           },
2308           kOperationStrDunderDocstring);
2309 
2310   //----------------------------------------------------------------------------
2311   // Mapping of Operation.
2312   //----------------------------------------------------------------------------
2313   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2314       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2315                              [](PyOperationBase &self) {
2316                                return self.getOperation().getCapsule();
2317                              })
2318       .def("__eq__",
2319            [](PyOperationBase &self, PyOperationBase &other) {
2320              return &self.getOperation() == &other.getOperation();
2321            })
2322       .def("__eq__",
2323            [](PyOperationBase &self, py::object other) { return false; })
2324       .def("__hash__",
2325            [](PyOperationBase &self) {
2326              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2327            })
2328       .def_property_readonly("attributes",
2329                              [](PyOperationBase &self) {
2330                                return PyOpAttributeMap(
2331                                    self.getOperation().getRef());
2332                              })
2333       .def_property_readonly("operands",
2334                              [](PyOperationBase &self) {
2335                                return PyOpOperandList(
2336                                    self.getOperation().getRef());
2337                              })
2338       .def_property_readonly("regions",
2339                              [](PyOperationBase &self) {
2340                                return PyRegionList(
2341                                    self.getOperation().getRef());
2342                              })
2343       .def_property_readonly(
2344           "results",
2345           [](PyOperationBase &self) {
2346             return PyOpResultList(self.getOperation().getRef());
2347           },
2348           "Returns the list of Operation results.")
2349       .def_property_readonly(
2350           "result",
2351           [](PyOperationBase &self) {
2352             auto &operation = self.getOperation();
2353             auto numResults = mlirOperationGetNumResults(operation);
2354             if (numResults != 1) {
2355               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2356               throw SetPyError(
2357                   PyExc_ValueError,
2358                   Twine("Cannot call .result on operation ") +
2359                       StringRef(name.data, name.length) + " which has " +
2360                       Twine(numResults) +
2361                       " results (it is only valid for operations with a "
2362                       "single result)");
2363             }
2364             return PyOpResult(operation.getRef(),
2365                               mlirOperationGetResult(operation, 0));
2366           },
2367           "Shortcut to get an op result if it has only one (throws an error "
2368           "otherwise).")
2369       .def_property_readonly(
2370           "location",
2371           [](PyOperationBase &self) {
2372             PyOperation &operation = self.getOperation();
2373             return PyLocation(operation.getContext(),
2374                               mlirOperationGetLocation(operation.get()));
2375           },
2376           "Returns the source location the operation was defined or derived "
2377           "from.")
2378       .def(
2379           "__str__",
2380           [](PyOperationBase &self) {
2381             return self.getAsm(/*binary=*/false,
2382                                /*largeElementsLimit=*/llvm::None,
2383                                /*enableDebugInfo=*/false,
2384                                /*prettyDebugInfo=*/false,
2385                                /*printGenericOpForm=*/false,
2386                                /*useLocalScope=*/false,
2387                                /*assumeVerified=*/false);
2388           },
2389           "Returns the assembly form of the operation.")
2390       .def("print", &PyOperationBase::print,
2391            // Careful: Lots of arguments must match up with print method.
2392            py::arg("file") = py::none(), py::arg("binary") = false,
2393            py::arg("large_elements_limit") = py::none(),
2394            py::arg("enable_debug_info") = false,
2395            py::arg("pretty_debug_info") = false,
2396            py::arg("print_generic_op_form") = false,
2397            py::arg("use_local_scope") = false,
2398            py::arg("assume_verified") = false, kOperationPrintDocstring)
2399       .def("get_asm", &PyOperationBase::getAsm,
2400            // Careful: Lots of arguments must match up with get_asm method.
2401            py::arg("binary") = false,
2402            py::arg("large_elements_limit") = py::none(),
2403            py::arg("enable_debug_info") = false,
2404            py::arg("pretty_debug_info") = false,
2405            py::arg("print_generic_op_form") = false,
2406            py::arg("use_local_scope") = false,
2407            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
2408       .def(
2409           "verify",
2410           [](PyOperationBase &self) {
2411             return mlirOperationVerify(self.getOperation());
2412           },
2413           "Verify the operation and return true if it passes, false if it "
2414           "fails.")
2415       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2416            "Puts self immediately after the other operation in its parent "
2417            "block.")
2418       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2419            "Puts self immediately before the other operation in its parent "
2420            "block.")
2421       .def(
2422           "detach_from_parent",
2423           [](PyOperationBase &self) {
2424             PyOperation &operation = self.getOperation();
2425             operation.checkValid();
2426             if (!operation.isAttached())
2427               throw py::value_error("Detached operation has no parent.");
2428 
2429             operation.detachFromParent();
2430             return operation.createOpView();
2431           },
2432           "Detaches the operation from its parent block.");
2433 
2434   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2435       .def_static("create", &PyOperation::create, py::arg("name"),
2436                   py::arg("results") = py::none(),
2437                   py::arg("operands") = py::none(),
2438                   py::arg("attributes") = py::none(),
2439                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2440                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2441                   kOperationCreateDocstring)
2442       .def_property_readonly("parent",
2443                              [](PyOperation &self) -> py::object {
2444                                auto parent = self.getParentOperation();
2445                                if (parent)
2446                                  return parent->getObject();
2447                                return py::none();
2448                              })
2449       .def("erase", &PyOperation::erase)
2450       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2451                              &PyOperation::getCapsule)
2452       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2453       .def_property_readonly("name",
2454                              [](PyOperation &self) {
2455                                self.checkValid();
2456                                MlirOperation operation = self.get();
2457                                MlirStringRef name = mlirIdentifierStr(
2458                                    mlirOperationGetName(operation));
2459                                return py::str(name.data, name.length);
2460                              })
2461       .def_property_readonly(
2462           "context",
2463           [](PyOperation &self) {
2464             self.checkValid();
2465             return self.getContext().getObject();
2466           },
2467           "Context that owns the Operation")
2468       .def_property_readonly("opview", &PyOperation::createOpView);
2469 
2470   auto opViewClass =
2471       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2472           .def(py::init<py::object>(), py::arg("operation"))
2473           .def_property_readonly("operation", &PyOpView::getOperationObject)
2474           .def_property_readonly(
2475               "context",
2476               [](PyOpView &self) {
2477                 return self.getOperation().getContext().getObject();
2478               },
2479               "Context that owns the Operation")
2480           .def("__str__", [](PyOpView &self) {
2481             return py::str(self.getOperationObject());
2482           });
2483   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2484   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2485   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2486   opViewClass.attr("build_generic") = classmethod(
2487       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2488       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2489       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2490       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2491       "Builds a specific, generated OpView based on class level attributes.");
2492 
2493   //----------------------------------------------------------------------------
2494   // Mapping of PyRegion.
2495   //----------------------------------------------------------------------------
2496   py::class_<PyRegion>(m, "Region", py::module_local())
2497       .def_property_readonly(
2498           "blocks",
2499           [](PyRegion &self) {
2500             return PyBlockList(self.getParentOperation(), self.get());
2501           },
2502           "Returns a forward-optimized sequence of blocks.")
2503       .def_property_readonly(
2504           "owner",
2505           [](PyRegion &self) {
2506             return self.getParentOperation()->createOpView();
2507           },
2508           "Returns the operation owning this region.")
2509       .def(
2510           "__iter__",
2511           [](PyRegion &self) {
2512             self.checkValid();
2513             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2514             return PyBlockIterator(self.getParentOperation(), firstBlock);
2515           },
2516           "Iterates over blocks in the region.")
2517       .def("__eq__",
2518            [](PyRegion &self, PyRegion &other) {
2519              return self.get().ptr == other.get().ptr;
2520            })
2521       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2522 
2523   //----------------------------------------------------------------------------
2524   // Mapping of PyBlock.
2525   //----------------------------------------------------------------------------
2526   py::class_<PyBlock>(m, "Block", py::module_local())
2527       .def_property_readonly(
2528           "owner",
2529           [](PyBlock &self) {
2530             return self.getParentOperation()->createOpView();
2531           },
2532           "Returns the owning operation of this block.")
2533       .def_property_readonly(
2534           "region",
2535           [](PyBlock &self) {
2536             MlirRegion region = mlirBlockGetParentRegion(self.get());
2537             return PyRegion(self.getParentOperation(), region);
2538           },
2539           "Returns the owning region of this block.")
2540       .def_property_readonly(
2541           "arguments",
2542           [](PyBlock &self) {
2543             return PyBlockArgumentList(self.getParentOperation(), self.get());
2544           },
2545           "Returns a list of block arguments.")
2546       .def_property_readonly(
2547           "operations",
2548           [](PyBlock &self) {
2549             return PyOperationList(self.getParentOperation(), self.get());
2550           },
2551           "Returns a forward-optimized sequence of operations.")
2552       .def_static(
2553           "create_at_start",
2554           [](PyRegion &parent, py::list pyArgTypes) {
2555             parent.checkValid();
2556             llvm::SmallVector<MlirType, 4> argTypes;
2557             argTypes.reserve(pyArgTypes.size());
2558             for (auto &pyArg : pyArgTypes) {
2559               argTypes.push_back(pyArg.cast<PyType &>());
2560             }
2561 
2562             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2563             mlirRegionInsertOwnedBlock(parent, 0, block);
2564             return PyBlock(parent.getParentOperation(), block);
2565           },
2566           py::arg("parent"), py::arg("arg_types") = py::list(),
2567           "Creates and returns a new Block at the beginning of the given "
2568           "region (with given argument types).")
2569       .def(
2570           "create_before",
2571           [](PyBlock &self, py::args pyArgTypes) {
2572             self.checkValid();
2573             llvm::SmallVector<MlirType, 4> argTypes;
2574             argTypes.reserve(pyArgTypes.size());
2575             for (auto &pyArg : pyArgTypes) {
2576               argTypes.push_back(pyArg.cast<PyType &>());
2577             }
2578 
2579             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2580             MlirRegion region = mlirBlockGetParentRegion(self.get());
2581             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2582             return PyBlock(self.getParentOperation(), block);
2583           },
2584           "Creates and returns a new Block before this block "
2585           "(with given argument types).")
2586       .def(
2587           "create_after",
2588           [](PyBlock &self, py::args pyArgTypes) {
2589             self.checkValid();
2590             llvm::SmallVector<MlirType, 4> argTypes;
2591             argTypes.reserve(pyArgTypes.size());
2592             for (auto &pyArg : pyArgTypes) {
2593               argTypes.push_back(pyArg.cast<PyType &>());
2594             }
2595 
2596             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2597             MlirRegion region = mlirBlockGetParentRegion(self.get());
2598             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2599             return PyBlock(self.getParentOperation(), block);
2600           },
2601           "Creates and returns a new Block after this block "
2602           "(with given argument types).")
2603       .def(
2604           "__iter__",
2605           [](PyBlock &self) {
2606             self.checkValid();
2607             MlirOperation firstOperation =
2608                 mlirBlockGetFirstOperation(self.get());
2609             return PyOperationIterator(self.getParentOperation(),
2610                                        firstOperation);
2611           },
2612           "Iterates over operations in the block.")
2613       .def("__eq__",
2614            [](PyBlock &self, PyBlock &other) {
2615              return self.get().ptr == other.get().ptr;
2616            })
2617       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2618       .def(
2619           "__str__",
2620           [](PyBlock &self) {
2621             self.checkValid();
2622             PyPrintAccumulator printAccum;
2623             mlirBlockPrint(self.get(), printAccum.getCallback(),
2624                            printAccum.getUserData());
2625             return printAccum.join();
2626           },
2627           "Returns the assembly form of the block.")
2628       .def(
2629           "append",
2630           [](PyBlock &self, PyOperationBase &operation) {
2631             if (operation.getOperation().isAttached())
2632               operation.getOperation().detachFromParent();
2633 
2634             MlirOperation mlirOperation = operation.getOperation().get();
2635             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2636             operation.getOperation().setAttached(
2637                 self.getParentOperation().getObject());
2638           },
2639           py::arg("operation"),
2640           "Appends an operation to this block. If the operation is currently "
2641           "in another block, it will be moved.");
2642 
2643   //----------------------------------------------------------------------------
2644   // Mapping of PyInsertionPoint.
2645   //----------------------------------------------------------------------------
2646 
2647   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2648       .def(py::init<PyBlock &>(), py::arg("block"),
2649            "Inserts after the last operation but still inside the block.")
2650       .def("__enter__", &PyInsertionPoint::contextEnter)
2651       .def("__exit__", &PyInsertionPoint::contextExit)
2652       .def_property_readonly_static(
2653           "current",
2654           [](py::object & /*class*/) {
2655             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2656             if (!ip)
2657               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2658             return ip;
2659           },
2660           "Gets the InsertionPoint bound to the current thread or raises "
2661           "ValueError if none has been set")
2662       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2663            "Inserts before a referenced operation.")
2664       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2665                   py::arg("block"), "Inserts at the beginning of the block.")
2666       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2667                   py::arg("block"), "Inserts before the block terminator.")
2668       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2669            "Inserts an operation.")
2670       .def_property_readonly(
2671           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2672           "Returns the block that this InsertionPoint points to.");
2673 
2674   //----------------------------------------------------------------------------
2675   // Mapping of PyAttribute.
2676   //----------------------------------------------------------------------------
2677   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2678       // Delegate to the PyAttribute copy constructor, which will also lifetime
2679       // extend the backing context which owns the MlirAttribute.
2680       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2681            "Casts the passed attribute to the generic Attribute")
2682       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2683                              &PyAttribute::getCapsule)
2684       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2685       .def_static(
2686           "parse",
2687           [](std::string attrSpec, DefaultingPyMlirContext context) {
2688             MlirAttribute type = mlirAttributeParseGet(
2689                 context->get(), toMlirStringRef(attrSpec));
2690             // TODO: Rework error reporting once diagnostic engine is exposed
2691             // in C API.
2692             if (mlirAttributeIsNull(type)) {
2693               throw SetPyError(PyExc_ValueError,
2694                                Twine("Unable to parse attribute: '") +
2695                                    attrSpec + "'");
2696             }
2697             return PyAttribute(context->getRef(), type);
2698           },
2699           py::arg("asm"), py::arg("context") = py::none(),
2700           "Parses an attribute from an assembly form")
2701       .def_property_readonly(
2702           "context",
2703           [](PyAttribute &self) { return self.getContext().getObject(); },
2704           "Context that owns the Attribute")
2705       .def_property_readonly("type",
2706                              [](PyAttribute &self) {
2707                                return PyType(self.getContext()->getRef(),
2708                                              mlirAttributeGetType(self));
2709                              })
2710       .def(
2711           "get_named",
2712           [](PyAttribute &self, std::string name) {
2713             return PyNamedAttribute(self, std::move(name));
2714           },
2715           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2716       .def("__eq__",
2717            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2718       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2719       .def("__hash__",
2720            [](PyAttribute &self) {
2721              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2722            })
2723       .def(
2724           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2725           kDumpDocstring)
2726       .def(
2727           "__str__",
2728           [](PyAttribute &self) {
2729             PyPrintAccumulator printAccum;
2730             mlirAttributePrint(self, printAccum.getCallback(),
2731                                printAccum.getUserData());
2732             return printAccum.join();
2733           },
2734           "Returns the assembly form of the Attribute.")
2735       .def("__repr__", [](PyAttribute &self) {
2736         // Generally, assembly formats are not printed for __repr__ because
2737         // this can cause exceptionally long debug output and exceptions.
2738         // However, attribute values are generally considered useful and are
2739         // printed. This may need to be re-evaluated if debug dumps end up
2740         // being excessive.
2741         PyPrintAccumulator printAccum;
2742         printAccum.parts.append("Attribute(");
2743         mlirAttributePrint(self, printAccum.getCallback(),
2744                            printAccum.getUserData());
2745         printAccum.parts.append(")");
2746         return printAccum.join();
2747       });
2748 
2749   //----------------------------------------------------------------------------
2750   // Mapping of PyNamedAttribute
2751   //----------------------------------------------------------------------------
2752   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2753       .def("__repr__",
2754            [](PyNamedAttribute &self) {
2755              PyPrintAccumulator printAccum;
2756              printAccum.parts.append("NamedAttribute(");
2757              printAccum.parts.append(
2758                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
2759                          mlirIdentifierStr(self.namedAttr.name).length));
2760              printAccum.parts.append("=");
2761              mlirAttributePrint(self.namedAttr.attribute,
2762                                 printAccum.getCallback(),
2763                                 printAccum.getUserData());
2764              printAccum.parts.append(")");
2765              return printAccum.join();
2766            })
2767       .def_property_readonly(
2768           "name",
2769           [](PyNamedAttribute &self) {
2770             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2771                            mlirIdentifierStr(self.namedAttr.name).length);
2772           },
2773           "The name of the NamedAttribute binding")
2774       .def_property_readonly(
2775           "attr",
2776           [](PyNamedAttribute &self) {
2777             // TODO: When named attribute is removed/refactored, also remove
2778             // this constructor (it does an inefficient table lookup).
2779             auto contextRef = PyMlirContext::forContext(
2780                 mlirAttributeGetContext(self.namedAttr.attribute));
2781             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2782           },
2783           py::keep_alive<0, 1>(),
2784           "The underlying generic attribute of the NamedAttribute binding");
2785 
2786   //----------------------------------------------------------------------------
2787   // Mapping of PyType.
2788   //----------------------------------------------------------------------------
2789   py::class_<PyType>(m, "Type", py::module_local())
2790       // Delegate to the PyType copy constructor, which will also lifetime
2791       // extend the backing context which owns the MlirType.
2792       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2793            "Casts the passed type to the generic Type")
2794       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2795       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2796       .def_static(
2797           "parse",
2798           [](std::string typeSpec, DefaultingPyMlirContext context) {
2799             MlirType type =
2800                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2801             // TODO: Rework error reporting once diagnostic engine is exposed
2802             // in C API.
2803             if (mlirTypeIsNull(type)) {
2804               throw SetPyError(PyExc_ValueError,
2805                                Twine("Unable to parse type: '") + typeSpec +
2806                                    "'");
2807             }
2808             return PyType(context->getRef(), type);
2809           },
2810           py::arg("asm"), py::arg("context") = py::none(),
2811           kContextParseTypeDocstring)
2812       .def_property_readonly(
2813           "context", [](PyType &self) { return self.getContext().getObject(); },
2814           "Context that owns the Type")
2815       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2816       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2817       .def("__hash__",
2818            [](PyType &self) {
2819              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2820            })
2821       .def(
2822           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2823       .def(
2824           "__str__",
2825           [](PyType &self) {
2826             PyPrintAccumulator printAccum;
2827             mlirTypePrint(self, printAccum.getCallback(),
2828                           printAccum.getUserData());
2829             return printAccum.join();
2830           },
2831           "Returns the assembly form of the type.")
2832       .def("__repr__", [](PyType &self) {
2833         // Generally, assembly formats are not printed for __repr__ because
2834         // this can cause exceptionally long debug output and exceptions.
2835         // However, types are an exception as they typically have compact
2836         // assembly forms and printing them is useful.
2837         PyPrintAccumulator printAccum;
2838         printAccum.parts.append("Type(");
2839         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2840         printAccum.parts.append(")");
2841         return printAccum.join();
2842       });
2843 
2844   //----------------------------------------------------------------------------
2845   // Mapping of Value.
2846   //----------------------------------------------------------------------------
2847   py::class_<PyValue>(m, "Value", py::module_local())
2848       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2849       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2850       .def_property_readonly(
2851           "context",
2852           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2853           "Context in which the value lives.")
2854       .def(
2855           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2856           kDumpDocstring)
2857       .def_property_readonly(
2858           "owner",
2859           [](PyValue &self) {
2860             assert(mlirOperationEqual(self.getParentOperation()->get(),
2861                                       mlirOpResultGetOwner(self.get())) &&
2862                    "expected the owner of the value in Python to match that in "
2863                    "the IR");
2864             return self.getParentOperation().getObject();
2865           })
2866       .def("__eq__",
2867            [](PyValue &self, PyValue &other) {
2868              return self.get().ptr == other.get().ptr;
2869            })
2870       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2871       .def("__hash__",
2872            [](PyValue &self) {
2873              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2874            })
2875       .def(
2876           "__str__",
2877           [](PyValue &self) {
2878             PyPrintAccumulator printAccum;
2879             printAccum.parts.append("Value(");
2880             mlirValuePrint(self.get(), printAccum.getCallback(),
2881                            printAccum.getUserData());
2882             printAccum.parts.append(")");
2883             return printAccum.join();
2884           },
2885           kValueDunderStrDocstring)
2886       .def_property_readonly("type", [](PyValue &self) {
2887         return PyType(self.getParentOperation()->getContext(),
2888                       mlirValueGetType(self.get()));
2889       });
2890   PyBlockArgument::bind(m);
2891   PyOpResult::bind(m);
2892 
2893   //----------------------------------------------------------------------------
2894   // Mapping of SymbolTable.
2895   //----------------------------------------------------------------------------
2896   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
2897       .def(py::init<PyOperationBase &>())
2898       .def("__getitem__", &PySymbolTable::dunderGetItem)
2899       .def("insert", &PySymbolTable::insert, py::arg("operation"))
2900       .def("erase", &PySymbolTable::erase, py::arg("operation"))
2901       .def("__delitem__", &PySymbolTable::dunderDel)
2902       .def("__contains__",
2903            [](PySymbolTable &table, const std::string &name) {
2904              return !mlirOperationIsNull(mlirSymbolTableLookup(
2905                  table, mlirStringRefCreate(name.data(), name.length())));
2906            })
2907       // Static helpers.
2908       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
2909                   py::arg("symbol"), py::arg("name"))
2910       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
2911                   py::arg("symbol"))
2912       .def_static("get_visibility", &PySymbolTable::getVisibility,
2913                   py::arg("symbol"))
2914       .def_static("set_visibility", &PySymbolTable::setVisibility,
2915                   py::arg("symbol"), py::arg("visibility"))
2916       .def_static("replace_all_symbol_uses",
2917                   &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
2918                   py::arg("new_symbol"), py::arg("from_op"))
2919       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
2920                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
2921                   py::arg("callback"));
2922 
  // Container bindings.
  // Helper list/iterator classes over block arguments, blocks, operations,
  // op attributes, operands, results and regions, each bound through its
  // static bind(m) hook (presumably registering one Python class on `m`).
  // NOTE(review): pybind11 formats a function's signature when the function is
  // defined, so reordering these registrations may change generated
  // docstrings; keep the order stable unless that is intended.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  // Binds the PyGlobalDebugFlag helper on the module (see its bind() for the
  // exact Python surface it exposes).
  PyGlobalDebugFlag::bind(m);
2937 }
2938