1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include <pybind11/stl.h>
23 
24 namespace py = pybind11;
25 using namespace mlir;
26 using namespace mlir::python;
27 
28 using llvm::SmallVector;
29 using llvm::StringRef;
30 using llvm::Twine;
31 
32 //------------------------------------------------------------------------------
33 // Docstrings (trivial, non-duplicated docstrings are included inline).
34 //------------------------------------------------------------------------------
35 
/// Help text describing how a type is parsed from its assembly form, what is
/// returned and which error is raised on failure.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

/// One-line help text for obtaining a call-site Location.
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

/// One-line help text for obtaining a file/line/column Location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

/// One-line help text for obtaining a named Location.
/// NOTE(review): this identifier ends in "DocString" while its siblings end in
/// "Docstring"; renaming would require touching every user of the constant.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

/// Help text describing how a module is parsed from a string.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

/// Help text describing the arguments and return value of operation creation.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
78 
/// Help text describing the keyword arguments accepted when printing an
/// operation to a file-like object. Argument names are spelled in
/// lower_snake_case, matching the other entries in the list.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";
104 
/// Help text for retrieving an operation's assembly as a bytes or str object.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

/// Help text for the default-options string conversion of an operation.
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

/// One-line help text for debug dumps to stderr.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

/// Help text for appending a block; attached to the "append" method of the
/// block list binding in this file.
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

/// Help text for the string conversion of a value.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
143 
144 //------------------------------------------------------------------------------
145 // Utilities.
146 //------------------------------------------------------------------------------
147 
148 /// Helper for creating an @classmethod.
149 template <class Func, typename... Args>
150 py::object classmethod(Func f, Args... args) {
151   py::object cf = py::cpp_function(f, args...);
152   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
153 }
154 
155 static py::object
156 createCustomDialectWrapper(const std::string &dialectNamespace,
157                            py::object dialectDescriptor) {
158   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
159   if (!dialectClass) {
160     // Use the base class.
161     return py::cast(PyDialect(std::move(dialectDescriptor)));
162   }
163 
164   // Create the custom implementation.
165   return (*dialectClass)(std::move(dialectDescriptor));
166 }
167 
168 static MlirStringRef toMlirStringRef(const std::string &s) {
169   return mlirStringRefCreate(s.data(), s.size());
170 }
171 
172 /// Wrapper for the global LLVM debugging flag.
173 struct PyGlobalDebugFlag {
174   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
175 
176   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
177 
178   static void bind(py::module &m) {
179     // Debug flags.
180     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
181         .def_property_static("flag", &PyGlobalDebugFlag::get,
182                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
183   }
184 };
185 
186 //------------------------------------------------------------------------------
187 // Collections.
188 //------------------------------------------------------------------------------
189 
190 namespace {
191 
192 class PyRegionIterator {
193 public:
194   PyRegionIterator(PyOperationRef operation)
195       : operation(std::move(operation)) {}
196 
197   PyRegionIterator &dunderIter() { return *this; }
198 
199   PyRegion dunderNext() {
200     operation->checkValid();
201     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
202       throw py::stop_iteration();
203     }
204     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
205     return PyRegion(operation, region);
206   }
207 
208   static void bind(py::module &m) {
209     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
210         .def("__iter__", &PyRegionIterator::dunderIter)
211         .def("__next__", &PyRegionIterator::dunderNext);
212   }
213 
214 private:
215   PyOperationRef operation;
216   int nextIndex = 0;
217 };
218 
219 /// Regions of an op are fixed length and indexed numerically so are represented
220 /// with a sequence-like container.
221 class PyRegionList {
222 public:
223   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
224 
225   intptr_t dunderLen() {
226     operation->checkValid();
227     return mlirOperationGetNumRegions(operation->get());
228   }
229 
230   PyRegion dunderGetItem(intptr_t index) {
231     // dunderLen checks validity.
232     if (index < 0 || index >= dunderLen()) {
233       throw SetPyError(PyExc_IndexError,
234                        "attempt to access out of bounds region");
235     }
236     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
237     return PyRegion(operation, region);
238   }
239 
240   static void bind(py::module &m) {
241     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
242         .def("__len__", &PyRegionList::dunderLen)
243         .def("__getitem__", &PyRegionList::dunderGetItem);
244   }
245 
246 private:
247   PyOperationRef operation;
248 };
249 
250 class PyBlockIterator {
251 public:
252   PyBlockIterator(PyOperationRef operation, MlirBlock next)
253       : operation(std::move(operation)), next(next) {}
254 
255   PyBlockIterator &dunderIter() { return *this; }
256 
257   PyBlock dunderNext() {
258     operation->checkValid();
259     if (mlirBlockIsNull(next)) {
260       throw py::stop_iteration();
261     }
262 
263     PyBlock returnBlock(operation, next);
264     next = mlirBlockGetNextInRegion(next);
265     return returnBlock;
266   }
267 
268   static void bind(py::module &m) {
269     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
270         .def("__iter__", &PyBlockIterator::dunderIter)
271         .def("__next__", &PyBlockIterator::dunderNext);
272   }
273 
274 private:
275   PyOperationRef operation;
276   MlirBlock next;
277 };
278 
279 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
280 /// we present them as a more full-featured list-like container but optimize
281 /// it for forward iteration. Blocks are always owned by a region.
282 class PyBlockList {
283 public:
284   PyBlockList(PyOperationRef operation, MlirRegion region)
285       : operation(std::move(operation)), region(region) {}
286 
287   PyBlockIterator dunderIter() {
288     operation->checkValid();
289     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
290   }
291 
292   intptr_t dunderLen() {
293     operation->checkValid();
294     intptr_t count = 0;
295     MlirBlock block = mlirRegionGetFirstBlock(region);
296     while (!mlirBlockIsNull(block)) {
297       count += 1;
298       block = mlirBlockGetNextInRegion(block);
299     }
300     return count;
301   }
302 
303   PyBlock dunderGetItem(intptr_t index) {
304     operation->checkValid();
305     if (index < 0) {
306       throw SetPyError(PyExc_IndexError,
307                        "attempt to access out of bounds block");
308     }
309     MlirBlock block = mlirRegionGetFirstBlock(region);
310     while (!mlirBlockIsNull(block)) {
311       if (index == 0) {
312         return PyBlock(operation, block);
313       }
314       block = mlirBlockGetNextInRegion(block);
315       index -= 1;
316     }
317     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
318   }
319 
320   PyBlock appendBlock(py::args pyArgTypes) {
321     operation->checkValid();
322     llvm::SmallVector<MlirType, 4> argTypes;
323     argTypes.reserve(pyArgTypes.size());
324     for (auto &pyArg : pyArgTypes) {
325       argTypes.push_back(pyArg.cast<PyType &>());
326     }
327 
328     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
329     mlirRegionAppendOwnedBlock(region, block);
330     return PyBlock(operation, block);
331   }
332 
333   static void bind(py::module &m) {
334     py::class_<PyBlockList>(m, "BlockList", py::module_local())
335         .def("__getitem__", &PyBlockList::dunderGetItem)
336         .def("__iter__", &PyBlockList::dunderIter)
337         .def("__len__", &PyBlockList::dunderLen)
338         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
339   }
340 
341 private:
342   PyOperationRef operation;
343   MlirRegion region;
344 };
345 
346 class PyOperationIterator {
347 public:
348   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
349       : parentOperation(std::move(parentOperation)), next(next) {}
350 
351   PyOperationIterator &dunderIter() { return *this; }
352 
353   py::object dunderNext() {
354     parentOperation->checkValid();
355     if (mlirOperationIsNull(next)) {
356       throw py::stop_iteration();
357     }
358 
359     PyOperationRef returnOperation =
360         PyOperation::forOperation(parentOperation->getContext(), next);
361     next = mlirOperationGetNextInBlock(next);
362     return returnOperation->createOpView();
363   }
364 
365   static void bind(py::module &m) {
366     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
367         .def("__iter__", &PyOperationIterator::dunderIter)
368         .def("__next__", &PyOperationIterator::dunderNext);
369   }
370 
371 private:
372   PyOperationRef parentOperation;
373   MlirOperation next;
374 };
375 
376 /// Operations are exposed by the C-API as a forward-only linked list. In
377 /// Python, we present them as a more full-featured list-like container but
378 /// optimize it for forward iteration. Iterable operations are always owned
379 /// by a block.
380 class PyOperationList {
381 public:
382   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
383       : parentOperation(std::move(parentOperation)), block(block) {}
384 
385   PyOperationIterator dunderIter() {
386     parentOperation->checkValid();
387     return PyOperationIterator(parentOperation,
388                                mlirBlockGetFirstOperation(block));
389   }
390 
391   intptr_t dunderLen() {
392     parentOperation->checkValid();
393     intptr_t count = 0;
394     MlirOperation childOp = mlirBlockGetFirstOperation(block);
395     while (!mlirOperationIsNull(childOp)) {
396       count += 1;
397       childOp = mlirOperationGetNextInBlock(childOp);
398     }
399     return count;
400   }
401 
402   py::object dunderGetItem(intptr_t index) {
403     parentOperation->checkValid();
404     if (index < 0) {
405       throw SetPyError(PyExc_IndexError,
406                        "attempt to access out of bounds operation");
407     }
408     MlirOperation childOp = mlirBlockGetFirstOperation(block);
409     while (!mlirOperationIsNull(childOp)) {
410       if (index == 0) {
411         return PyOperation::forOperation(parentOperation->getContext(), childOp)
412             ->createOpView();
413       }
414       childOp = mlirOperationGetNextInBlock(childOp);
415       index -= 1;
416     }
417     throw SetPyError(PyExc_IndexError,
418                      "attempt to access out of bounds operation");
419   }
420 
421   static void bind(py::module &m) {
422     py::class_<PyOperationList>(m, "OperationList", py::module_local())
423         .def("__getitem__", &PyOperationList::dunderGetItem)
424         .def("__iter__", &PyOperationList::dunderIter)
425         .def("__len__", &PyOperationList::dunderLen);
426   }
427 
428 private:
429   PyOperationRef parentOperation;
430   MlirBlock block;
431 };
432 
433 } // namespace
434 
435 //------------------------------------------------------------------------------
436 // PyMlirContext
437 //------------------------------------------------------------------------------
438 
439 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
440   py::gil_scoped_acquire acquire;
441   auto &liveContexts = getLiveContexts();
442   liveContexts[context.ptr] = this;
443 }
444 
445 PyMlirContext::~PyMlirContext() {
446   // Note that the only public way to construct an instance is via the
447   // forContext method, which always puts the associated handle into
448   // liveContexts.
449   py::gil_scoped_acquire acquire;
450   getLiveContexts().erase(context.ptr);
451   mlirContextDestroy(context);
452 }
453 
454 py::object PyMlirContext::getCapsule() {
455   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
456 }
457 
458 py::object PyMlirContext::createFromCapsule(py::object capsule) {
459   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
460   if (mlirContextIsNull(rawContext))
461     throw py::error_already_set();
462   return forContext(rawContext).releaseObject();
463 }
464 
465 PyMlirContext *PyMlirContext::createNewContextForInit() {
466   MlirContext context = mlirContextCreate();
467   mlirRegisterAllDialects(context);
468   return new PyMlirContext(context);
469 }
470 
471 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
472   py::gil_scoped_acquire acquire;
473   auto &liveContexts = getLiveContexts();
474   auto it = liveContexts.find(context.ptr);
475   if (it == liveContexts.end()) {
476     // Create.
477     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
478     py::object pyRef = py::cast(unownedContextWrapper);
479     assert(pyRef && "cast to py::object failed");
480     liveContexts[context.ptr] = unownedContextWrapper;
481     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
482   }
483   // Use existing.
484   py::object pyRef = py::cast(it->second);
485   return PyMlirContextRef(it->second, std::move(pyRef));
486 }
487 
488 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
489   static LiveContextMap liveContexts;
490   return liveContexts;
491 }
492 
493 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
494 
495 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
496 
497 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
498 
499 pybind11::object PyMlirContext::contextEnter() {
500   return PyThreadContextEntry::pushContext(*this);
501 }
502 
503 void PyMlirContext::contextExit(pybind11::object excType,
504                                 pybind11::object excVal,
505                                 pybind11::object excTb) {
506   PyThreadContextEntry::popContext(*this);
507 }
508 
509 PyMlirContext &DefaultingPyMlirContext::resolve() {
510   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
511   if (!context) {
512     throw SetPyError(
513         PyExc_RuntimeError,
514         "An MLIR function requires a Context but none was provided in the call "
515         "or from the surrounding environment. Either pass to the function with "
516         "a 'context=' argument or establish a default using 'with Context():'");
517   }
518   return *context;
519 }
520 
521 //------------------------------------------------------------------------------
522 // PyThreadContextEntry management
523 //------------------------------------------------------------------------------
524 
525 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
526   static thread_local std::vector<PyThreadContextEntry> stack;
527   return stack;
528 }
529 
530 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
531   auto &stack = getStack();
532   if (stack.empty())
533     return nullptr;
534   return &stack.back();
535 }
536 
537 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
538                                 py::object insertionPoint,
539                                 py::object location) {
540   auto &stack = getStack();
541   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
542                      std::move(location));
543   // If the new stack has more than one entry and the context of the new top
544   // entry matches the previous, copy the insertionPoint and location from the
545   // previous entry if missing from the new top entry.
546   if (stack.size() > 1) {
547     auto &prev = *(stack.rbegin() + 1);
548     auto &current = stack.back();
549     if (current.context.is(prev.context)) {
550       // Default non-context objects from the previous entry.
551       if (!current.insertionPoint)
552         current.insertionPoint = prev.insertionPoint;
553       if (!current.location)
554         current.location = prev.location;
555     }
556   }
557 }
558 
559 PyMlirContext *PyThreadContextEntry::getContext() {
560   if (!context)
561     return nullptr;
562   return py::cast<PyMlirContext *>(context);
563 }
564 
565 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
566   if (!insertionPoint)
567     return nullptr;
568   return py::cast<PyInsertionPoint *>(insertionPoint);
569 }
570 
571 PyLocation *PyThreadContextEntry::getLocation() {
572   if (!location)
573     return nullptr;
574   return py::cast<PyLocation *>(location);
575 }
576 
577 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
578   auto *tos = getTopOfStack();
579   return tos ? tos->getContext() : nullptr;
580 }
581 
582 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
583   auto *tos = getTopOfStack();
584   return tos ? tos->getInsertionPoint() : nullptr;
585 }
586 
587 PyLocation *PyThreadContextEntry::getDefaultLocation() {
588   auto *tos = getTopOfStack();
589   return tos ? tos->getLocation() : nullptr;
590 }
591 
592 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
593   py::object contextObj = py::cast(context);
594   push(FrameKind::Context, /*context=*/contextObj,
595        /*insertionPoint=*/py::object(),
596        /*location=*/py::object());
597   return contextObj;
598 }
599 
600 void PyThreadContextEntry::popContext(PyMlirContext &context) {
601   auto &stack = getStack();
602   if (stack.empty())
603     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
604   auto &tos = stack.back();
605   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
606     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
607   stack.pop_back();
608 }
609 
610 py::object
611 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
612   py::object contextObj =
613       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
614   py::object insertionPointObj = py::cast(insertionPoint);
615   push(FrameKind::InsertionPoint,
616        /*context=*/contextObj,
617        /*insertionPoint=*/insertionPointObj,
618        /*location=*/py::object());
619   return insertionPointObj;
620 }
621 
622 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
623   auto &stack = getStack();
624   if (stack.empty())
625     throw SetPyError(PyExc_RuntimeError,
626                      "Unbalanced InsertionPoint enter/exit");
627   auto &tos = stack.back();
628   if (tos.frameKind != FrameKind::InsertionPoint &&
629       tos.getInsertionPoint() != &insertionPoint)
630     throw SetPyError(PyExc_RuntimeError,
631                      "Unbalanced InsertionPoint enter/exit");
632   stack.pop_back();
633 }
634 
635 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
636   py::object contextObj = location.getContext().getObject();
637   py::object locationObj = py::cast(location);
638   push(FrameKind::Location, /*context=*/contextObj,
639        /*insertionPoint=*/py::object(),
640        /*location=*/locationObj);
641   return locationObj;
642 }
643 
644 void PyThreadContextEntry::popLocation(PyLocation &location) {
645   auto &stack = getStack();
646   if (stack.empty())
647     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
648   auto &tos = stack.back();
649   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
650     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
651   stack.pop_back();
652 }
653 
654 //------------------------------------------------------------------------------
655 // PyDialect, PyDialectDescriptor, PyDialects
656 //------------------------------------------------------------------------------
657 
658 MlirDialect PyDialects::getDialectForKey(const std::string &key,
659                                          bool attrError) {
660   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
661                                                     {key.data(), key.size()});
662   if (mlirDialectIsNull(dialect)) {
663     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
664                      Twine("Dialect '") + key + "' not found");
665   }
666   return dialect;
667 }
668 
669 //------------------------------------------------------------------------------
670 // PyLocation
671 //------------------------------------------------------------------------------
672 
673 py::object PyLocation::getCapsule() {
674   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
675 }
676 
677 PyLocation PyLocation::createFromCapsule(py::object capsule) {
678   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
679   if (mlirLocationIsNull(rawLoc))
680     throw py::error_already_set();
681   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
682                     rawLoc);
683 }
684 
685 py::object PyLocation::contextEnter() {
686   return PyThreadContextEntry::pushLocation(*this);
687 }
688 
689 void PyLocation::contextExit(py::object excType, py::object excVal,
690                              py::object excTb) {
691   PyThreadContextEntry::popLocation(*this);
692 }
693 
694 PyLocation &DefaultingPyLocation::resolve() {
695   auto *location = PyThreadContextEntry::getDefaultLocation();
696   if (!location) {
697     throw SetPyError(
698         PyExc_RuntimeError,
699         "An MLIR function requires a Location but none was provided in the "
700         "call or from the surrounding environment. Either pass to the function "
701         "with a 'loc=' argument or establish a default using 'with loc:'");
702   }
703   return *location;
704 }
705 
706 //------------------------------------------------------------------------------
707 // PyModule
708 //------------------------------------------------------------------------------
709 
710 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
711     : BaseContextObject(std::move(contextRef)), module(module) {}
712 
713 PyModule::~PyModule() {
714   py::gil_scoped_acquire acquire;
715   auto &liveModules = getContext()->liveModules;
716   assert(liveModules.count(module.ptr) == 1 &&
717          "destroying module not in live map");
718   liveModules.erase(module.ptr);
719   mlirModuleDestroy(module);
720 }
721 
722 PyModuleRef PyModule::forModule(MlirModule module) {
723   MlirContext context = mlirModuleGetContext(module);
724   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
725 
726   py::gil_scoped_acquire acquire;
727   auto &liveModules = contextRef->liveModules;
728   auto it = liveModules.find(module.ptr);
729   if (it == liveModules.end()) {
730     // Create.
731     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
732     // Note that the default return value policy on cast is automatic_reference,
733     // which does not take ownership (delete will not be called).
734     // Just be explicit.
735     py::object pyRef =
736         py::cast(unownedModule, py::return_value_policy::take_ownership);
737     unownedModule->handle = pyRef;
738     liveModules[module.ptr] =
739         std::make_pair(unownedModule->handle, unownedModule);
740     return PyModuleRef(unownedModule, std::move(pyRef));
741   }
742   // Use existing.
743   PyModule *existing = it->second.second;
744   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
745   return PyModuleRef(existing, std::move(pyRef));
746 }
747 
748 py::object PyModule::createFromCapsule(py::object capsule) {
749   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
750   if (mlirModuleIsNull(rawModule))
751     throw py::error_already_set();
752   return forModule(rawModule).releaseObject();
753 }
754 
755 py::object PyModule::getCapsule() {
756   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
757 }
758 
759 //------------------------------------------------------------------------------
760 // PyOperation
761 //------------------------------------------------------------------------------
762 
763 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
764     : BaseContextObject(std::move(contextRef)), operation(operation) {}
765 
766 PyOperation::~PyOperation() {
767   // If the operation has already been invalidated there is nothing to do.
768   if (!valid)
769     return;
770   auto &liveOperations = getContext()->liveOperations;
771   assert(liveOperations.count(operation.ptr) == 1 &&
772          "destroying operation not in live map");
773   liveOperations.erase(operation.ptr);
774   if (!isAttached()) {
775     mlirOperationDestroy(operation);
776   }
777 }
778 
779 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
780                                            MlirOperation operation,
781                                            py::object parentKeepAlive) {
782   auto &liveOperations = contextRef->liveOperations;
783   // Create.
784   PyOperation *unownedOperation =
785       new PyOperation(std::move(contextRef), operation);
786   // Note that the default return value policy on cast is automatic_reference,
787   // which does not take ownership (delete will not be called).
788   // Just be explicit.
789   py::object pyRef =
790       py::cast(unownedOperation, py::return_value_policy::take_ownership);
791   unownedOperation->handle = pyRef;
792   if (parentKeepAlive) {
793     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
794   }
795   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
796   return PyOperationRef(unownedOperation, std::move(pyRef));
797 }
798 
799 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
800                                          MlirOperation operation,
801                                          py::object parentKeepAlive) {
802   auto &liveOperations = contextRef->liveOperations;
803   auto it = liveOperations.find(operation.ptr);
804   if (it == liveOperations.end()) {
805     // Create.
806     return createInstance(std::move(contextRef), operation,
807                           std::move(parentKeepAlive));
808   }
809   // Use existing.
810   PyOperation *existing = it->second.second;
811   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
812   return PyOperationRef(existing, std::move(pyRef));
813 }
814 
815 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
816                                            MlirOperation operation,
817                                            py::object parentKeepAlive) {
818   auto &liveOperations = contextRef->liveOperations;
819   assert(liveOperations.count(operation.ptr) == 0 &&
820          "cannot create detached operation that already exists");
821   (void)liveOperations;
822 
823   PyOperationRef created = createInstance(std::move(contextRef), operation,
824                                           std::move(parentKeepAlive));
825   created->attached = false;
826   return created;
827 }
828 
829 void PyOperation::checkValid() const {
830   if (!valid) {
831     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
832   }
833 }
834 
835 void PyOperationBase::print(py::object fileObject, bool binary,
836                             llvm::Optional<int64_t> largeElementsLimit,
837                             bool enableDebugInfo, bool prettyDebugInfo,
838                             bool printGenericOpForm, bool useLocalScope,
839                             bool assumeVerified) {
840   PyOperation &operation = getOperation();
841   operation.checkValid();
842   if (fileObject.is_none())
843     fileObject = py::module::import("sys").attr("stdout");
844 
845   if (!assumeVerified && !printGenericOpForm &&
846       !mlirOperationVerify(operation)) {
847     std::string message("// Verification failed, printing generic form\n");
848     if (binary) {
849       fileObject.attr("write")(py::bytes(message));
850     } else {
851       fileObject.attr("write")(py::str(message));
852     }
853     printGenericOpForm = true;
854   }
855 
856   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
857   if (largeElementsLimit)
858     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
859   if (enableDebugInfo)
860     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
861   if (printGenericOpForm)
862     mlirOpPrintingFlagsPrintGenericOpForm(flags);
863 
864   PyFileAccumulator accum(fileObject, binary);
865   py::gil_scoped_release();
866   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
867                               accum.getUserData());
868   mlirOpPrintingFlagsDestroy(flags);
869 }
870 
871 py::object PyOperationBase::getAsm(bool binary,
872                                    llvm::Optional<int64_t> largeElementsLimit,
873                                    bool enableDebugInfo, bool prettyDebugInfo,
874                                    bool printGenericOpForm, bool useLocalScope,
875                                    bool assumeVerified) {
876   py::object fileObject;
877   if (binary) {
878     fileObject = py::module::import("io").attr("BytesIO")();
879   } else {
880     fileObject = py::module::import("io").attr("StringIO")();
881   }
882   print(fileObject, /*binary=*/binary,
883         /*largeElementsLimit=*/largeElementsLimit,
884         /*enableDebugInfo=*/enableDebugInfo,
885         /*prettyDebugInfo=*/prettyDebugInfo,
886         /*printGenericOpForm=*/printGenericOpForm,
887         /*useLocalScope=*/useLocalScope,
888         /*assumeVerified=*/assumeVerified);
889 
890   return fileObject.attr("getvalue")();
891 }
892 
893 void PyOperationBase::moveAfter(PyOperationBase &other) {
894   PyOperation &operation = getOperation();
895   PyOperation &otherOp = other.getOperation();
896   operation.checkValid();
897   otherOp.checkValid();
898   mlirOperationMoveAfter(operation, otherOp);
899   operation.parentKeepAlive = otherOp.parentKeepAlive;
900 }
901 
902 void PyOperationBase::moveBefore(PyOperationBase &other) {
903   PyOperation &operation = getOperation();
904   PyOperation &otherOp = other.getOperation();
905   operation.checkValid();
906   otherOp.checkValid();
907   mlirOperationMoveBefore(operation, otherOp);
908   operation.parentKeepAlive = otherOp.parentKeepAlive;
909 }
910 
911 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
912   checkValid();
913   if (!isAttached())
914     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
915   MlirOperation operation = mlirOperationGetParentOperation(get());
916   if (mlirOperationIsNull(operation))
917     return {};
918   return PyOperation::forOperation(getContext(), operation);
919 }
920 
921 PyBlock PyOperation::getBlock() {
922   checkValid();
923   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
924   MlirBlock block = mlirOperationGetBlock(get());
925   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
926   assert(parentOperation && "Operation has no parent");
927   return PyBlock{std::move(*parentOperation), block};
928 }
929 
930 py::object PyOperation::getCapsule() {
931   checkValid();
932   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
933 }
934 
935 py::object PyOperation::createFromCapsule(py::object capsule) {
936   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
937   if (mlirOperationIsNull(rawOperation))
938     throw py::error_already_set();
939   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
940   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
941       .releaseObject();
942 }
943 
944 py::object PyOperation::create(
945     std::string name, llvm::Optional<std::vector<PyType *>> results,
946     llvm::Optional<std::vector<PyValue *>> operands,
947     llvm::Optional<py::dict> attributes,
948     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
949     DefaultingPyLocation location, py::object maybeIp) {
950   llvm::SmallVector<MlirValue, 4> mlirOperands;
951   llvm::SmallVector<MlirType, 4> mlirResults;
952   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
953   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
954 
955   // General parameter validation.
956   if (regions < 0)
957     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
958 
959   // Unpack/validate operands.
960   if (operands) {
961     mlirOperands.reserve(operands->size());
962     for (PyValue *operand : *operands) {
963       if (!operand)
964         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
965       mlirOperands.push_back(operand->get());
966     }
967   }
968 
969   // Unpack/validate results.
970   if (results) {
971     mlirResults.reserve(results->size());
972     for (PyType *result : *results) {
973       // TODO: Verify result type originate from the same context.
974       if (!result)
975         throw SetPyError(PyExc_ValueError, "result type cannot be None");
976       mlirResults.push_back(*result);
977     }
978   }
979   // Unpack/validate attributes.
980   if (attributes) {
981     mlirAttributes.reserve(attributes->size());
982     for (auto &it : *attributes) {
983       std::string key;
984       try {
985         key = it.first.cast<std::string>();
986       } catch (py::cast_error &err) {
987         std::string msg = "Invalid attribute key (not a string) when "
988                           "attempting to create the operation \"" +
989                           name + "\" (" + err.what() + ")";
990         throw py::cast_error(msg);
991       }
992       try {
993         auto &attribute = it.second.cast<PyAttribute &>();
994         // TODO: Verify attribute originates from the same context.
995         mlirAttributes.emplace_back(std::move(key), attribute);
996       } catch (py::reference_cast_error &) {
997         // This exception seems thrown when the value is "None".
998         std::string msg =
999             "Found an invalid (`None`?) attribute value for the key \"" + key +
1000             "\" when attempting to create the operation \"" + name + "\"";
1001         throw py::cast_error(msg);
1002       } catch (py::cast_error &err) {
1003         std::string msg = "Invalid attribute value for the key \"" + key +
1004                           "\" when attempting to create the operation \"" +
1005                           name + "\" (" + err.what() + ")";
1006         throw py::cast_error(msg);
1007       }
1008     }
1009   }
1010   // Unpack/validate successors.
1011   if (successors) {
1012     mlirSuccessors.reserve(successors->size());
1013     for (auto *successor : *successors) {
1014       // TODO: Verify successor originate from the same context.
1015       if (!successor)
1016         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
1017       mlirSuccessors.push_back(successor->get());
1018     }
1019   }
1020 
1021   // Apply unpacked/validated to the operation state. Beyond this
1022   // point, exceptions cannot be thrown or else the state will leak.
1023   MlirOperationState state =
1024       mlirOperationStateGet(toMlirStringRef(name), location);
1025   if (!mlirOperands.empty())
1026     mlirOperationStateAddOperands(&state, mlirOperands.size(),
1027                                   mlirOperands.data());
1028   if (!mlirResults.empty())
1029     mlirOperationStateAddResults(&state, mlirResults.size(),
1030                                  mlirResults.data());
1031   if (!mlirAttributes.empty()) {
1032     // Note that the attribute names directly reference bytes in
1033     // mlirAttributes, so that vector must not be changed from here
1034     // on.
1035     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1036     mlirNamedAttributes.reserve(mlirAttributes.size());
1037     for (auto &it : mlirAttributes)
1038       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1039           mlirIdentifierGet(mlirAttributeGetContext(it.second),
1040                             toMlirStringRef(it.first)),
1041           it.second));
1042     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1043                                     mlirNamedAttributes.data());
1044   }
1045   if (!mlirSuccessors.empty())
1046     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1047                                     mlirSuccessors.data());
1048   if (regions) {
1049     llvm::SmallVector<MlirRegion, 4> mlirRegions;
1050     mlirRegions.resize(regions);
1051     for (int i = 0; i < regions; ++i)
1052       mlirRegions[i] = mlirRegionCreate();
1053     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1054                                       mlirRegions.data());
1055   }
1056 
1057   // Construct the operation.
1058   MlirOperation operation = mlirOperationCreate(&state);
1059   PyOperationRef created =
1060       PyOperation::createDetached(location->getContext(), operation);
1061 
1062   // InsertPoint active?
1063   if (!maybeIp.is(py::cast(false))) {
1064     PyInsertionPoint *ip;
1065     if (maybeIp.is_none()) {
1066       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1067     } else {
1068       ip = py::cast<PyInsertionPoint *>(maybeIp);
1069     }
1070     if (ip)
1071       ip->insert(*created.get());
1072   }
1073 
1074   return created->createOpView();
1075 }
1076 
1077 py::object PyOperation::createOpView() {
1078   checkValid();
1079   MlirIdentifier ident = mlirOperationGetName(get());
1080   MlirStringRef identStr = mlirIdentifierStr(ident);
1081   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1082       StringRef(identStr.data, identStr.length));
1083   if (opViewClass)
1084     return (*opViewClass)(getRef().getObject());
1085   return py::cast(PyOpView(getRef().getObject()));
1086 }
1087 
1088 void PyOperation::erase() {
1089   checkValid();
1090   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1091   // Python reference to a child operation is live. All children should also
1092   // have their `valid` bit set to false.
1093   auto &liveOperations = getContext()->liveOperations;
1094   if (liveOperations.count(operation.ptr))
1095     liveOperations.erase(operation.ptr);
1096   mlirOperationDestroy(operation);
1097   valid = false;
1098 }
1099 
1100 //------------------------------------------------------------------------------
1101 // PyOpView
1102 //------------------------------------------------------------------------------
1103 
1104 py::object
1105 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1106                        py::list operandList,
1107                        llvm::Optional<py::dict> attributes,
1108                        llvm::Optional<std::vector<PyBlock *>> successors,
1109                        llvm::Optional<int> regions,
1110                        DefaultingPyLocation location, py::object maybeIp) {
1111   PyMlirContextRef context = location->getContext();
1112   // Class level operation construction metadata.
1113   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1114   // Operand and result segment specs are either none, which does no
1115   // variadic unpacking, or a list of ints with segment sizes, where each
1116   // element is either a positive number (typically 1 for a scalar) or -1 to
1117   // indicate that it is derived from the length of the same-indexed operand
1118   // or result (implying that it is a list at that position).
1119   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1120   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1121 
1122   std::vector<uint32_t> operandSegmentLengths;
1123   std::vector<uint32_t> resultSegmentLengths;
1124 
1125   // Validate/determine region count.
1126   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1127   int opMinRegionCount = std::get<0>(opRegionSpec);
1128   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1129   if (!regions) {
1130     regions = opMinRegionCount;
1131   }
1132   if (*regions < opMinRegionCount) {
1133     throw py::value_error(
1134         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1135          llvm::Twine(opMinRegionCount) +
1136          " regions but was built with regions=" + llvm::Twine(*regions))
1137             .str());
1138   }
1139   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1140     throw py::value_error(
1141         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1142          llvm::Twine(opMinRegionCount) +
1143          " regions but was built with regions=" + llvm::Twine(*regions))
1144             .str());
1145   }
1146 
1147   // Unpack results.
1148   std::vector<PyType *> resultTypes;
1149   resultTypes.reserve(resultTypeList.size());
1150   if (resultSegmentSpecObj.is_none()) {
1151     // Non-variadic result unpacking.
1152     for (auto it : llvm::enumerate(resultTypeList)) {
1153       try {
1154         resultTypes.push_back(py::cast<PyType *>(it.value()));
1155         if (!resultTypes.back())
1156           throw py::cast_error();
1157       } catch (py::cast_error &err) {
1158         throw py::value_error((llvm::Twine("Result ") +
1159                                llvm::Twine(it.index()) + " of operation \"" +
1160                                name + "\" must be a Type (" + err.what() + ")")
1161                                   .str());
1162       }
1163     }
1164   } else {
1165     // Sized result unpacking.
1166     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1167     if (resultSegmentSpec.size() != resultTypeList.size()) {
1168       throw py::value_error((llvm::Twine("Operation \"") + name +
1169                              "\" requires " +
1170                              llvm::Twine(resultSegmentSpec.size()) +
1171                              " result segments but was provided " +
1172                              llvm::Twine(resultTypeList.size()))
1173                                 .str());
1174     }
1175     resultSegmentLengths.reserve(resultTypeList.size());
1176     for (auto it :
1177          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1178       int segmentSpec = std::get<1>(it.value());
1179       if (segmentSpec == 1 || segmentSpec == 0) {
1180         // Unpack unary element.
1181         try {
1182           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1183           if (resultType) {
1184             resultTypes.push_back(resultType);
1185             resultSegmentLengths.push_back(1);
1186           } else if (segmentSpec == 0) {
1187             // Allowed to be optional.
1188             resultSegmentLengths.push_back(0);
1189           } else {
1190             throw py::cast_error("was None and result is not optional");
1191           }
1192         } catch (py::cast_error &err) {
1193           throw py::value_error((llvm::Twine("Result ") +
1194                                  llvm::Twine(it.index()) + " of operation \"" +
1195                                  name + "\" must be a Type (" + err.what() +
1196                                  ")")
1197                                     .str());
1198         }
1199       } else if (segmentSpec == -1) {
1200         // Unpack sequence by appending.
1201         try {
1202           if (std::get<0>(it.value()).is_none()) {
1203             // Treat it as an empty list.
1204             resultSegmentLengths.push_back(0);
1205           } else {
1206             // Unpack the list.
1207             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1208             for (py::object segmentItem : segment) {
1209               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1210               if (!resultTypes.back()) {
1211                 throw py::cast_error("contained a None item");
1212               }
1213             }
1214             resultSegmentLengths.push_back(segment.size());
1215           }
1216         } catch (std::exception &err) {
1217           // NOTE: Sloppy to be using a catch-all here, but there are at least
1218           // three different unrelated exceptions that can be thrown in the
1219           // above "casts". Just keep the scope above small and catch them all.
1220           throw py::value_error((llvm::Twine("Result ") +
1221                                  llvm::Twine(it.index()) + " of operation \"" +
1222                                  name + "\" must be a Sequence of Types (" +
1223                                  err.what() + ")")
1224                                     .str());
1225         }
1226       } else {
1227         throw py::value_error("Unexpected segment spec");
1228       }
1229     }
1230   }
1231 
1232   // Unpack operands.
1233   std::vector<PyValue *> operands;
1234   operands.reserve(operands.size());
1235   if (operandSegmentSpecObj.is_none()) {
1236     // Non-sized operand unpacking.
1237     for (auto it : llvm::enumerate(operandList)) {
1238       try {
1239         operands.push_back(py::cast<PyValue *>(it.value()));
1240         if (!operands.back())
1241           throw py::cast_error();
1242       } catch (py::cast_error &err) {
1243         throw py::value_error((llvm::Twine("Operand ") +
1244                                llvm::Twine(it.index()) + " of operation \"" +
1245                                name + "\" must be a Value (" + err.what() + ")")
1246                                   .str());
1247       }
1248     }
1249   } else {
1250     // Sized operand unpacking.
1251     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1252     if (operandSegmentSpec.size() != operandList.size()) {
1253       throw py::value_error((llvm::Twine("Operation \"") + name +
1254                              "\" requires " +
1255                              llvm::Twine(operandSegmentSpec.size()) +
1256                              "operand segments but was provided " +
1257                              llvm::Twine(operandList.size()))
1258                                 .str());
1259     }
1260     operandSegmentLengths.reserve(operandList.size());
1261     for (auto it :
1262          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1263       int segmentSpec = std::get<1>(it.value());
1264       if (segmentSpec == 1 || segmentSpec == 0) {
1265         // Unpack unary element.
1266         try {
1267           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1268           if (operandValue) {
1269             operands.push_back(operandValue);
1270             operandSegmentLengths.push_back(1);
1271           } else if (segmentSpec == 0) {
1272             // Allowed to be optional.
1273             operandSegmentLengths.push_back(0);
1274           } else {
1275             throw py::cast_error("was None and operand is not optional");
1276           }
1277         } catch (py::cast_error &err) {
1278           throw py::value_error((llvm::Twine("Operand ") +
1279                                  llvm::Twine(it.index()) + " of operation \"" +
1280                                  name + "\" must be a Value (" + err.what() +
1281                                  ")")
1282                                     .str());
1283         }
1284       } else if (segmentSpec == -1) {
1285         // Unpack sequence by appending.
1286         try {
1287           if (std::get<0>(it.value()).is_none()) {
1288             // Treat it as an empty list.
1289             operandSegmentLengths.push_back(0);
1290           } else {
1291             // Unpack the list.
1292             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1293             for (py::object segmentItem : segment) {
1294               operands.push_back(py::cast<PyValue *>(segmentItem));
1295               if (!operands.back()) {
1296                 throw py::cast_error("contained a None item");
1297               }
1298             }
1299             operandSegmentLengths.push_back(segment.size());
1300           }
1301         } catch (std::exception &err) {
1302           // NOTE: Sloppy to be using a catch-all here, but there are at least
1303           // three different unrelated exceptions that can be thrown in the
1304           // above "casts". Just keep the scope above small and catch them all.
1305           throw py::value_error((llvm::Twine("Operand ") +
1306                                  llvm::Twine(it.index()) + " of operation \"" +
1307                                  name + "\" must be a Sequence of Values (" +
1308                                  err.what() + ")")
1309                                     .str());
1310         }
1311       } else {
1312         throw py::value_error("Unexpected segment spec");
1313       }
1314     }
1315   }
1316 
1317   // Merge operand/result segment lengths into attributes if needed.
1318   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1319     // Dup.
1320     if (attributes) {
1321       attributes = py::dict(*attributes);
1322     } else {
1323       attributes = py::dict();
1324     }
1325     if (attributes->contains("result_segment_sizes") ||
1326         attributes->contains("operand_segment_sizes")) {
1327       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1328                             "'operand_segment_sizes' attribute is unsupported. "
1329                             "Use Operation.create for such low-level access.");
1330     }
1331 
1332     // Add result_segment_sizes attribute.
1333     if (!resultSegmentLengths.empty()) {
1334       int64_t size = resultSegmentLengths.size();
1335       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1336           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1337           resultSegmentLengths.size(), resultSegmentLengths.data());
1338       (*attributes)["result_segment_sizes"] =
1339           PyAttribute(context, segmentLengthAttr);
1340     }
1341 
1342     // Add operand_segment_sizes attribute.
1343     if (!operandSegmentLengths.empty()) {
1344       int64_t size = operandSegmentLengths.size();
1345       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1346           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1347           operandSegmentLengths.size(), operandSegmentLengths.data());
1348       (*attributes)["operand_segment_sizes"] =
1349           PyAttribute(context, segmentLengthAttr);
1350     }
1351   }
1352 
1353   // Delegate to create.
1354   return PyOperation::create(std::move(name),
1355                              /*results=*/std::move(resultTypes),
1356                              /*operands=*/std::move(operands),
1357                              /*attributes=*/std::move(attributes),
1358                              /*successors=*/std::move(successors),
1359                              /*regions=*/*regions, location, maybeIp);
1360 }
1361 
1362 PyOpView::PyOpView(py::object operationObject)
1363     // Casting through the PyOperationBase base-class and then back to the
1364     // Operation lets us accept any PyOperationBase subclass.
1365     : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1366       operationObject(operation.getRef().getObject()) {}
1367 
1368 py::object PyOpView::createRawSubclass(py::object userClass) {
1369   // This is... a little gross. The typical pattern is to have a pure python
1370   // class that extends OpView like:
1371   //   class AddFOp(_cext.ir.OpView):
1372   //     def __init__(self, loc, lhs, rhs):
1373   //       operation = loc.context.create_operation(
1374   //           "addf", lhs, rhs, results=[lhs.type])
1375   //       super().__init__(operation)
1376   //
1377   // I.e. The goal of the user facing type is to provide a nice constructor
1378   // that has complete freedom for the op under construction. This is at odds
1379   // with our other desire to sometimes create this object by just passing an
1380   // operation (to initialize the base class). We could do *arg and **kwargs
1381   // munging to try to make it work, but instead, we synthesize a new class
1382   // on the fly which extends this user class (AddFOp in this example) and
1383   // *give it* the base class's __init__ method, thus bypassing the
1384   // intermediate subclass's __init__ method entirely. While slightly,
1385   // underhanded, this is safe/legal because the type hierarchy has not changed
1386   // (we just added a new leaf) and we aren't mucking around with __new__.
1387   // Typically, this new class will be stored on the original as "_Raw" and will
1388   // be used for casts and other things that need a variant of the class that
1389   // is initialized purely from an operation.
1390   py::object parentMetaclass =
1391       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1392   py::dict attributes;
1393   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1394   // now.
1395   //   auto opViewType = py::type::of<PyOpView>();
1396   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1397   attributes["__init__"] = opViewType.attr("__init__");
1398   py::str origName = userClass.attr("__name__");
1399   py::str newName = py::str("_") + origName;
1400   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1401 }
1402 
1403 //------------------------------------------------------------------------------
1404 // PyInsertionPoint.
1405 //------------------------------------------------------------------------------
1406 
// Creates an insertion point on `block` with no reference operation; insert()
// then places new operations at the end of the block.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1408 
// Creates an insertion point that inserts before the given operation. The
// block is obtained from the reference operation via getBlock().
// NOTE(review): `block` is initialized from `refOperation`, so this presumably
// relies on `refOperation` being declared before `block` in the class —
// confirm against the header.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1412 
1413 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1414   PyOperation &operation = operationBase.getOperation();
1415   if (operation.isAttached())
1416     throw SetPyError(PyExc_ValueError,
1417                      "Attempt to insert operation that is already attached");
1418   block.getParentOperation()->checkValid();
1419   MlirOperation beforeOp = {nullptr};
1420   if (refOperation) {
1421     // Insert before operation.
1422     (*refOperation)->checkValid();
1423     beforeOp = (*refOperation)->get();
1424   } else {
1425     // Insert at end (before null) is only valid if the block does not
1426     // already end in a known terminator (violating this will cause assertion
1427     // failures later).
1428     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1429       throw py::index_error("Cannot insert operation at the end of a block "
1430                             "that already has a terminator. Did you mean to "
1431                             "use 'InsertionPoint.at_block_terminator(block)' "
1432                             "versus 'InsertionPoint(block)'?");
1433     }
1434   }
1435   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1436   operation.setAttached();
1437 }
1438 
1439 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1440   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1441   if (mlirOperationIsNull(firstOp)) {
1442     // Just insert at end.
1443     return PyInsertionPoint(block);
1444   }
1445 
1446   // Insert before first op.
1447   PyOperationRef firstOpRef = PyOperation::forOperation(
1448       block.getParentOperation()->getContext(), firstOp);
1449   return PyInsertionPoint{block, std::move(firstOpRef)};
1450 }
1451 
1452 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1453   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1454   if (mlirOperationIsNull(terminator))
1455     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1456   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1457       block.getParentOperation()->getContext(), terminator);
1458   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1459 }
1460 
1461 py::object PyInsertionPoint::contextEnter() {
1462   return PyThreadContextEntry::pushInsertionPoint(*this);
1463 }
1464 
1465 void PyInsertionPoint::contextExit(pybind11::object excType,
1466                                    pybind11::object excVal,
1467                                    pybind11::object excTb) {
1468   PyThreadContextEntry::popInsertionPoint(*this);
1469 }
1470 
1471 //------------------------------------------------------------------------------
1472 // PyAttribute.
1473 //------------------------------------------------------------------------------
1474 
1475 bool PyAttribute::operator==(const PyAttribute &other) {
1476   return mlirAttributeEqual(attr, other.attr);
1477 }
1478 
1479 py::object PyAttribute::getCapsule() {
1480   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1481 }
1482 
1483 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1484   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1485   if (mlirAttributeIsNull(rawAttr))
1486     throw py::error_already_set();
1487   return PyAttribute(
1488       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1489 }
1490 
1491 //------------------------------------------------------------------------------
1492 // PyNamedAttribute.
1493 //------------------------------------------------------------------------------
1494 
1495 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1496     : ownedName(new std::string(std::move(ownedName))) {
1497   namedAttr = mlirNamedAttributeGet(
1498       mlirIdentifierGet(mlirAttributeGetContext(attr),
1499                         toMlirStringRef(*this->ownedName)),
1500       attr);
1501 }
1502 
1503 //------------------------------------------------------------------------------
1504 // PyType.
1505 //------------------------------------------------------------------------------
1506 
1507 bool PyType::operator==(const PyType &other) {
1508   return mlirTypeEqual(type, other.type);
1509 }
1510 
1511 py::object PyType::getCapsule() {
1512   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1513 }
1514 
1515 PyType PyType::createFromCapsule(py::object capsule) {
1516   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1517   if (mlirTypeIsNull(rawType))
1518     throw py::error_already_set();
1519   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1520                 rawType);
1521 }
1522 
1523 //------------------------------------------------------------------------------
// PyValue and subclasses.
1525 //------------------------------------------------------------------------------
1526 
1527 pybind11::object PyValue::getCapsule() {
1528   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1529 }
1530 
1531 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1532   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1533   if (mlirValueIsNull(value))
1534     throw py::error_already_set();
1535   MlirOperation owner;
1536   if (mlirValueIsAOpResult(value))
1537     owner = mlirOpResultGetOwner(value);
1538   if (mlirValueIsABlockArgument(value))
1539     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1540   if (mlirOperationIsNull(owner))
1541     throw py::error_already_set();
1542   MlirContext ctx = mlirOperationGetContext(owner);
1543   PyOperationRef ownerRef =
1544       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1545   return PyValue(ownerRef, value);
1546 }
1547 
1548 //------------------------------------------------------------------------------
1549 // PySymbolTable.
1550 //------------------------------------------------------------------------------
1551 
1552 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1553     : operation(operation.getOperation().getRef()) {
1554   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1555   if (mlirSymbolTableIsNull(symbolTable)) {
1556     throw py::cast_error("Operation is not a Symbol Table.");
1557   }
1558 }
1559 
1560 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1561   operation->checkValid();
1562   MlirOperation symbol = mlirSymbolTableLookup(
1563       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1564   if (mlirOperationIsNull(symbol))
1565     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1566 
1567   return PyOperation::forOperation(operation->getContext(), symbol,
1568                                    operation.getObject())
1569       ->createOpView();
1570 }
1571 
1572 void PySymbolTable::erase(PyOperationBase &symbol) {
1573   operation->checkValid();
1574   symbol.getOperation().checkValid();
1575   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1576   // The operation is also erased, so we must invalidate it. There may be Python
1577   // references to this operation so we don't want to delete it from the list of
1578   // live operations here.
1579   symbol.getOperation().valid = false;
1580 }
1581 
1582 void PySymbolTable::dunderDel(const std::string &name) {
1583   py::object operation = dunderGetItem(name);
1584   erase(py::cast<PyOperationBase &>(operation));
1585 }
1586 
1587 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1588   operation->checkValid();
1589   symbol.getOperation().checkValid();
1590   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1591       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1592   if (mlirAttributeIsNull(symbolAttr))
1593     throw py::value_error("Expected operation to have a symbol name.");
1594   return PyAttribute(
1595       symbol.getOperation().getContext(),
1596       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1597 }
1598 
1599 namespace {
1600 /// CRTP base class for Python MLIR values that subclass Value and should be
1601 /// castable from it. The value hierarchy is one level deep and is not supposed
1602 /// to accommodate other levels unless core MLIR changes.
1603 template <typename DerivedTy>
1604 class PyConcreteValue : public PyValue {
1605 public:
1606   // Derived classes must define statics for:
1607   //   IsAFunctionTy isaFunction
1608   //   const char *pyClassName
1609   // and redefine bindDerived.
1610   using ClassTy = py::class_<DerivedTy, PyValue>;
1611   using IsAFunctionTy = bool (*)(MlirValue);
1612 
1613   PyConcreteValue() = default;
1614   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1615       : PyValue(operationRef, value) {}
1616   PyConcreteValue(PyValue &orig)
1617       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1618 
1619   /// Attempts to cast the original value to the derived type and throws on
1620   /// type mismatches.
1621   static MlirValue castFrom(PyValue &orig) {
1622     if (!DerivedTy::isaFunction(orig.get())) {
1623       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1624       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1625                                              DerivedTy::pyClassName +
1626                                              " (from " + origRepr + ")");
1627     }
1628     return orig.get();
1629   }
1630 
1631   /// Binds the Python module objects to functions of this class.
1632   static void bind(py::module &m) {
1633     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1634     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1635     cls.def_static(
1636         "isinstance",
1637         [](PyValue &otherValue) -> bool {
1638           return DerivedTy::isaFunction(otherValue);
1639         },
1640         py::arg("other_value"));
1641     DerivedTy::bindDerived(cls);
1642   }
1643 
1644   /// Implemented by derived classes to add methods to the Python subclass.
1645   static void bindDerived(ClassTy &m) {}
1646 };
1647 
1648 /// Python wrapper for MlirBlockArgument.
1649 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1650 public:
1651   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1652   static constexpr const char *pyClassName = "BlockArgument";
1653   using PyConcreteValue::PyConcreteValue;
1654 
1655   static void bindDerived(ClassTy &c) {
1656     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1657       return PyBlock(self.getParentOperation(),
1658                      mlirBlockArgumentGetOwner(self.get()));
1659     });
1660     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1661       return mlirBlockArgumentGetArgNumber(self.get());
1662     });
1663     c.def(
1664         "set_type",
1665         [](PyBlockArgument &self, PyType type) {
1666           return mlirBlockArgumentSetType(self.get(), type);
1667         },
1668         py::arg("type"));
1669   }
1670 };
1671 
1672 /// Python wrapper for MlirOpResult.
1673 class PyOpResult : public PyConcreteValue<PyOpResult> {
1674 public:
1675   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1676   static constexpr const char *pyClassName = "OpResult";
1677   using PyConcreteValue::PyConcreteValue;
1678 
1679   static void bindDerived(ClassTy &c) {
1680     c.def_property_readonly("owner", [](PyOpResult &self) {
1681       assert(
1682           mlirOperationEqual(self.getParentOperation()->get(),
1683                              mlirOpResultGetOwner(self.get())) &&
1684           "expected the owner of the value in Python to match that in the IR");
1685       return self.getParentOperation().getObject();
1686     });
1687     c.def_property_readonly("result_number", [](PyOpResult &self) {
1688       return mlirOpResultGetResultNumber(self.get());
1689     });
1690   }
1691 };
1692 
1693 /// Returns the list of types of the values held by container.
1694 template <typename Container>
1695 static std::vector<PyType> getValueTypes(Container &container,
1696                                          PyMlirContextRef &context) {
1697   std::vector<PyType> result;
1698   result.reserve(container.getNumElements());
1699   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1700     result.push_back(
1701         PyType(context, mlirValueGetType(container.getElement(i).get())));
1702   }
1703   return result;
1704 }
1705 
1706 /// A list of block arguments. Internally, these are stored as consecutive
1707 /// elements, random access is cheap. The argument list is associated with the
1708 /// operation that contains the block (detached blocks are not allowed in
1709 /// Python bindings) and extends its lifetime.
1710 class PyBlockArgumentList
1711     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1712 public:
1713   static constexpr const char *pyClassName = "BlockArgumentList";
1714 
1715   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1716                       intptr_t startIndex = 0, intptr_t length = -1,
1717                       intptr_t step = 1)
1718       : Sliceable(startIndex,
1719                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1720                   step),
1721         operation(std::move(operation)), block(block) {}
1722 
1723   /// Returns the number of arguments in the list.
1724   intptr_t getNumElements() {
1725     operation->checkValid();
1726     return mlirBlockGetNumArguments(block);
1727   }
1728 
1729   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1730   PyBlockArgument getElement(intptr_t pos) {
1731     MlirValue argument = mlirBlockGetArgument(block, pos);
1732     return PyBlockArgument(operation, argument);
1733   }
1734 
1735   /// Returns a sublist of this list.
1736   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1737                             intptr_t step) {
1738     return PyBlockArgumentList(operation, block, startIndex, length, step);
1739   }
1740 
1741   static void bindDerived(ClassTy &c) {
1742     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1743       return getValueTypes(self, self.operation->getContext());
1744     });
1745   }
1746 
1747 private:
1748   PyOperationRef operation;
1749   MlirBlock block;
1750 };
1751 
1752 /// A list of operation operands. Internally, these are stored as consecutive
1753 /// elements, random access is cheap. The result list is associated with the
1754 /// operation whose results these are, and extends the lifetime of this
1755 /// operation.
1756 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1757 public:
1758   static constexpr const char *pyClassName = "OpOperandList";
1759 
1760   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1761                   intptr_t length = -1, intptr_t step = 1)
1762       : Sliceable(startIndex,
1763                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1764                                : length,
1765                   step),
1766         operation(operation) {}
1767 
1768   intptr_t getNumElements() {
1769     operation->checkValid();
1770     return mlirOperationGetNumOperands(operation->get());
1771   }
1772 
1773   PyValue getElement(intptr_t pos) {
1774     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1775     MlirOperation owner;
1776     if (mlirValueIsAOpResult(operand))
1777       owner = mlirOpResultGetOwner(operand);
1778     else if (mlirValueIsABlockArgument(operand))
1779       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1780     else
1781       assert(false && "Value must be an block arg or op result.");
1782     PyOperationRef pyOwner =
1783         PyOperation::forOperation(operation->getContext(), owner);
1784     return PyValue(pyOwner, operand);
1785   }
1786 
1787   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1788     return PyOpOperandList(operation, startIndex, length, step);
1789   }
1790 
1791   void dunderSetItem(intptr_t index, PyValue value) {
1792     index = wrapIndex(index);
1793     mlirOperationSetOperand(operation->get(), index, value.get());
1794   }
1795 
1796   static void bindDerived(ClassTy &c) {
1797     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1798   }
1799 
1800 private:
1801   PyOperationRef operation;
1802 };
1803 
1804 /// A list of operation results. Internally, these are stored as consecutive
1805 /// elements, random access is cheap. The result list is associated with the
1806 /// operation whose results these are, and extends the lifetime of this
1807 /// operation.
1808 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1809 public:
1810   static constexpr const char *pyClassName = "OpResultList";
1811 
1812   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1813                  intptr_t length = -1, intptr_t step = 1)
1814       : Sliceable(startIndex,
1815                   length == -1 ? mlirOperationGetNumResults(operation->get())
1816                                : length,
1817                   step),
1818         operation(operation) {}
1819 
1820   intptr_t getNumElements() {
1821     operation->checkValid();
1822     return mlirOperationGetNumResults(operation->get());
1823   }
1824 
1825   PyOpResult getElement(intptr_t index) {
1826     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1827     return PyOpResult(value);
1828   }
1829 
1830   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1831     return PyOpResultList(operation, startIndex, length, step);
1832   }
1833 
1834   static void bindDerived(ClassTy &c) {
1835     c.def_property_readonly("types", [](PyOpResultList &self) {
1836       return getValueTypes(self, self.operation->getContext());
1837     });
1838   }
1839 
1840 private:
1841   PyOperationRef operation;
1842 };
1843 
1844 /// A list of operation attributes. Can be indexed by name, producing
1845 /// attributes, or by index, producing named attributes.
1846 class PyOpAttributeMap {
1847 public:
1848   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1849 
1850   PyAttribute dunderGetItemNamed(const std::string &name) {
1851     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1852                                                          toMlirStringRef(name));
1853     if (mlirAttributeIsNull(attr)) {
1854       throw SetPyError(PyExc_KeyError,
1855                        "attempt to access a non-existent attribute");
1856     }
1857     return PyAttribute(operation->getContext(), attr);
1858   }
1859 
1860   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1861     if (index < 0 || index >= dunderLen()) {
1862       throw SetPyError(PyExc_IndexError,
1863                        "attempt to access out of bounds attribute");
1864     }
1865     MlirNamedAttribute namedAttr =
1866         mlirOperationGetAttribute(operation->get(), index);
1867     return PyNamedAttribute(
1868         namedAttr.attribute,
1869         std::string(mlirIdentifierStr(namedAttr.name).data,
1870                     mlirIdentifierStr(namedAttr.name).length));
1871   }
1872 
1873   void dunderSetItem(const std::string &name, PyAttribute attr) {
1874     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1875                                     attr);
1876   }
1877 
1878   void dunderDelItem(const std::string &name) {
1879     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1880                                                      toMlirStringRef(name));
1881     if (!removed)
1882       throw SetPyError(PyExc_KeyError,
1883                        "attempt to delete a non-existent attribute");
1884   }
1885 
1886   intptr_t dunderLen() {
1887     return mlirOperationGetNumAttributes(operation->get());
1888   }
1889 
1890   bool dunderContains(const std::string &name) {
1891     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1892         operation->get(), toMlirStringRef(name)));
1893   }
1894 
1895   static void bind(py::module &m) {
1896     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
1897         .def("__contains__", &PyOpAttributeMap::dunderContains)
1898         .def("__len__", &PyOpAttributeMap::dunderLen)
1899         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1900         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1901         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1902         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1903   }
1904 
1905 private:
1906   PyOperationRef operation;
1907 };
1908 
1909 } // end namespace
1910 
1911 //------------------------------------------------------------------------------
1912 // Populates the core exports of the 'ir' submodule.
1913 //------------------------------------------------------------------------------
1914 
1915 void mlir::python::populateIRCore(py::module &m) {
1916   //----------------------------------------------------------------------------
1917   // Mapping of MlirContext.
1918   //----------------------------------------------------------------------------
1919   py::class_<PyMlirContext>(m, "Context", py::module_local())
1920       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1921       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1922       .def("_get_context_again",
1923            [](PyMlirContext &self) {
1924              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1925              return ref.releaseObject();
1926            })
1927       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1928       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1929       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1930                              &PyMlirContext::getCapsule)
1931       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1932       .def("__enter__", &PyMlirContext::contextEnter)
1933       .def("__exit__", &PyMlirContext::contextExit)
1934       .def_property_readonly_static(
1935           "current",
1936           [](py::object & /*class*/) {
1937             auto *context = PyThreadContextEntry::getDefaultContext();
1938             if (!context)
1939               throw SetPyError(PyExc_ValueError, "No current Context");
1940             return context;
1941           },
1942           "Gets the Context bound to the current thread or raises ValueError")
1943       .def_property_readonly(
1944           "dialects",
1945           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1946           "Gets a container for accessing dialects by name")
1947       .def_property_readonly(
1948           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1949           "Alias for 'dialect'")
1950       .def(
1951           "get_dialect_descriptor",
1952           [=](PyMlirContext &self, std::string &name) {
1953             MlirDialect dialect = mlirContextGetOrLoadDialect(
1954                 self.get(), {name.data(), name.size()});
1955             if (mlirDialectIsNull(dialect)) {
1956               throw SetPyError(PyExc_ValueError,
1957                                Twine("Dialect '") + name + "' not found");
1958             }
1959             return PyDialectDescriptor(self.getRef(), dialect);
1960           },
1961           py::arg("dialect_name"),
1962           "Gets or loads a dialect by name, returning its descriptor object")
1963       .def_property(
1964           "allow_unregistered_dialects",
1965           [](PyMlirContext &self) -> bool {
1966             return mlirContextGetAllowUnregisteredDialects(self.get());
1967           },
1968           [](PyMlirContext &self, bool value) {
1969             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1970           })
1971       .def(
1972           "enable_multithreading",
1973           [](PyMlirContext &self, bool enable) {
1974             mlirContextEnableMultithreading(self.get(), enable);
1975           },
1976           py::arg("enable"))
1977       .def(
1978           "is_registered_operation",
1979           [](PyMlirContext &self, std::string &name) {
1980             return mlirContextIsRegisteredOperation(
1981                 self.get(), MlirStringRef{name.data(), name.size()});
1982           },
1983           py::arg("operation_name"));
1984 
1985   //----------------------------------------------------------------------------
1986   // Mapping of PyDialectDescriptor
1987   //----------------------------------------------------------------------------
1988   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
1989       .def_property_readonly("namespace",
1990                              [](PyDialectDescriptor &self) {
1991                                MlirStringRef ns =
1992                                    mlirDialectGetNamespace(self.get());
1993                                return py::str(ns.data, ns.length);
1994                              })
1995       .def("__repr__", [](PyDialectDescriptor &self) {
1996         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1997         std::string repr("<DialectDescriptor ");
1998         repr.append(ns.data, ns.length);
1999         repr.append(">");
2000         return repr;
2001       });
2002 
2003   //----------------------------------------------------------------------------
2004   // Mapping of PyDialects
2005   //----------------------------------------------------------------------------
2006   py::class_<PyDialects>(m, "Dialects", py::module_local())
2007       .def("__getitem__",
2008            [=](PyDialects &self, std::string keyName) {
2009              MlirDialect dialect =
2010                  self.getDialectForKey(keyName, /*attrError=*/false);
2011              py::object descriptor =
2012                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2013              return createCustomDialectWrapper(keyName, std::move(descriptor));
2014            })
2015       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2016         MlirDialect dialect =
2017             self.getDialectForKey(attrName, /*attrError=*/true);
2018         py::object descriptor =
2019             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2020         return createCustomDialectWrapper(attrName, std::move(descriptor));
2021       });
2022 
2023   //----------------------------------------------------------------------------
2024   // Mapping of PyDialect
2025   //----------------------------------------------------------------------------
2026   py::class_<PyDialect>(m, "Dialect", py::module_local())
2027       .def(py::init<py::object>(), py::arg("descriptor"))
2028       .def_property_readonly(
2029           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2030       .def("__repr__", [](py::object self) {
2031         auto clazz = self.attr("__class__");
2032         return py::str("<Dialect ") +
2033                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2034                clazz.attr("__module__") + py::str(".") +
2035                clazz.attr("__name__") + py::str(")>");
2036       });
2037 
2038   //----------------------------------------------------------------------------
2039   // Mapping of Location
2040   //----------------------------------------------------------------------------
2041   py::class_<PyLocation>(m, "Location", py::module_local())
2042       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2043       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2044       .def("__enter__", &PyLocation::contextEnter)
2045       .def("__exit__", &PyLocation::contextExit)
2046       .def("__eq__",
2047            [](PyLocation &self, PyLocation &other) -> bool {
2048              return mlirLocationEqual(self, other);
2049            })
2050       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2051       .def_property_readonly_static(
2052           "current",
2053           [](py::object & /*class*/) {
2054             auto *loc = PyThreadContextEntry::getDefaultLocation();
2055             if (!loc)
2056               throw SetPyError(PyExc_ValueError, "No current Location");
2057             return loc;
2058           },
2059           "Gets the Location bound to the current thread or raises ValueError")
2060       .def_static(
2061           "unknown",
2062           [](DefaultingPyMlirContext context) {
2063             return PyLocation(context->getRef(),
2064                               mlirLocationUnknownGet(context->get()));
2065           },
2066           py::arg("context") = py::none(),
2067           "Gets a Location representing an unknown location")
2068       .def_static(
2069           "callsite",
2070           [](PyLocation callee, const std::vector<PyLocation> &frames,
2071              DefaultingPyMlirContext context) {
2072             if (frames.empty())
2073               throw py::value_error("No caller frames provided");
2074             MlirLocation caller = frames.back().get();
2075             for (const PyLocation &frame :
2076                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2077               caller = mlirLocationCallSiteGet(frame.get(), caller);
2078             return PyLocation(context->getRef(),
2079                               mlirLocationCallSiteGet(callee.get(), caller));
2080           },
2081           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2082           kContextGetCallSiteLocationDocstring)
2083       .def_static(
2084           "file",
2085           [](std::string filename, int line, int col,
2086              DefaultingPyMlirContext context) {
2087             return PyLocation(
2088                 context->getRef(),
2089                 mlirLocationFileLineColGet(
2090                     context->get(), toMlirStringRef(filename), line, col));
2091           },
2092           py::arg("filename"), py::arg("line"), py::arg("col"),
2093           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2094       .def_static(
2095           "name",
2096           [](std::string name, llvm::Optional<PyLocation> childLoc,
2097              DefaultingPyMlirContext context) {
2098             return PyLocation(
2099                 context->getRef(),
2100                 mlirLocationNameGet(
2101                     context->get(), toMlirStringRef(name),
2102                     childLoc ? childLoc->get()
2103                              : mlirLocationUnknownGet(context->get())));
2104           },
2105           py::arg("name"), py::arg("childLoc") = py::none(),
2106           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2107       .def_property_readonly(
2108           "context",
2109           [](PyLocation &self) { return self.getContext().getObject(); },
2110           "Context that owns the Location")
2111       .def("__repr__", [](PyLocation &self) {
2112         PyPrintAccumulator printAccum;
2113         mlirLocationPrint(self, printAccum.getCallback(),
2114                           printAccum.getUserData());
2115         return printAccum.join();
2116       });
2117 
2118   //----------------------------------------------------------------------------
2119   // Mapping of Module
2120   //----------------------------------------------------------------------------
2121   py::class_<PyModule>(m, "Module", py::module_local())
2122       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2123       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2124       .def_static(
2125           "parse",
2126           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2127             MlirModule module = mlirModuleCreateParse(
2128                 context->get(), toMlirStringRef(moduleAsm));
2129             // TODO: Rework error reporting once diagnostic engine is exposed
2130             // in C API.
2131             if (mlirModuleIsNull(module)) {
2132               throw SetPyError(
2133                   PyExc_ValueError,
2134                   "Unable to parse module assembly (see diagnostics)");
2135             }
2136             return PyModule::forModule(module).releaseObject();
2137           },
2138           py::arg("asm"), py::arg("context") = py::none(),
2139           kModuleParseDocstring)
2140       .def_static(
2141           "create",
2142           [](DefaultingPyLocation loc) {
2143             MlirModule module = mlirModuleCreateEmpty(loc);
2144             return PyModule::forModule(module).releaseObject();
2145           },
2146           py::arg("loc") = py::none(), "Creates an empty module")
2147       .def_property_readonly(
2148           "context",
2149           [](PyModule &self) { return self.getContext().getObject(); },
2150           "Context that created the Module")
2151       .def_property_readonly(
2152           "operation",
2153           [](PyModule &self) {
2154             return PyOperation::forOperation(self.getContext(),
2155                                              mlirModuleGetOperation(self.get()),
2156                                              self.getRef().releaseObject())
2157                 .releaseObject();
2158           },
2159           "Accesses the module as an operation")
2160       .def_property_readonly(
2161           "body",
2162           [](PyModule &self) {
2163             PyOperationRef module_op = PyOperation::forOperation(
2164                 self.getContext(), mlirModuleGetOperation(self.get()),
2165                 self.getRef().releaseObject());
2166             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2167             return returnBlock;
2168           },
2169           "Return the block for this module")
2170       .def(
2171           "dump",
2172           [](PyModule &self) {
2173             mlirOperationDump(mlirModuleGetOperation(self.get()));
2174           },
2175           kDumpDocstring)
2176       .def(
2177           "__str__",
2178           [](py::object self) {
2179             // Defer to the operation's __str__.
2180             return self.attr("operation").attr("__str__")();
2181           },
2182           kOperationStrDunderDocstring);
2183 
2184   //----------------------------------------------------------------------------
2185   // Mapping of Operation.
2186   //----------------------------------------------------------------------------
2187   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2188       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2189                              [](PyOperationBase &self) {
2190                                return self.getOperation().getCapsule();
2191                              })
2192       .def("__eq__",
2193            [](PyOperationBase &self, PyOperationBase &other) {
2194              return &self.getOperation() == &other.getOperation();
2195            })
2196       .def("__eq__",
2197            [](PyOperationBase &self, py::object other) { return false; })
2198       .def("__hash__",
2199            [](PyOperationBase &self) {
2200              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2201            })
2202       .def_property_readonly("attributes",
2203                              [](PyOperationBase &self) {
2204                                return PyOpAttributeMap(
2205                                    self.getOperation().getRef());
2206                              })
2207       .def_property_readonly("operands",
2208                              [](PyOperationBase &self) {
2209                                return PyOpOperandList(
2210                                    self.getOperation().getRef());
2211                              })
2212       .def_property_readonly("regions",
2213                              [](PyOperationBase &self) {
2214                                return PyRegionList(
2215                                    self.getOperation().getRef());
2216                              })
2217       .def_property_readonly(
2218           "results",
2219           [](PyOperationBase &self) {
2220             return PyOpResultList(self.getOperation().getRef());
2221           },
2222           "Returns the list of Operation results.")
2223       .def_property_readonly(
2224           "result",
2225           [](PyOperationBase &self) {
2226             auto &operation = self.getOperation();
2227             auto numResults = mlirOperationGetNumResults(operation);
2228             if (numResults != 1) {
2229               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2230               throw SetPyError(
2231                   PyExc_ValueError,
2232                   Twine("Cannot call .result on operation ") +
2233                       StringRef(name.data, name.length) + " which has " +
2234                       Twine(numResults) +
2235                       " results (it is only valid for operations with a "
2236                       "single result)");
2237             }
2238             return PyOpResult(operation.getRef(),
2239                               mlirOperationGetResult(operation, 0));
2240           },
2241           "Shortcut to get an op result if it has only one (throws an error "
2242           "otherwise).")
2243       .def_property_readonly(
2244           "location",
2245           [](PyOperationBase &self) {
2246             PyOperation &operation = self.getOperation();
2247             return PyLocation(operation.getContext(),
2248                               mlirOperationGetLocation(operation.get()));
2249           },
2250           "Returns the source location the operation was defined or derived "
2251           "from.")
2252       .def(
2253           "__str__",
2254           [](PyOperationBase &self) {
2255             return self.getAsm(/*binary=*/false,
2256                                /*largeElementsLimit=*/llvm::None,
2257                                /*enableDebugInfo=*/false,
2258                                /*prettyDebugInfo=*/false,
2259                                /*printGenericOpForm=*/false,
2260                                /*useLocalScope=*/false,
2261                                /*assumeVerified=*/false);
2262           },
2263           "Returns the assembly form of the operation.")
2264       .def("print", &PyOperationBase::print,
2265            // Careful: Lots of arguments must match up with print method.
2266            py::arg("file") = py::none(), py::arg("binary") = false,
2267            py::arg("large_elements_limit") = py::none(),
2268            py::arg("enable_debug_info") = false,
2269            py::arg("pretty_debug_info") = false,
2270            py::arg("print_generic_op_form") = false,
2271            py::arg("use_local_scope") = false,
2272            py::arg("assume_verified") = false, kOperationPrintDocstring)
2273       .def("get_asm", &PyOperationBase::getAsm,
2274            // Careful: Lots of arguments must match up with get_asm method.
2275            py::arg("binary") = false,
2276            py::arg("large_elements_limit") = py::none(),
2277            py::arg("enable_debug_info") = false,
2278            py::arg("pretty_debug_info") = false,
2279            py::arg("print_generic_op_form") = false,
2280            py::arg("use_local_scope") = false,
2281            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
2282       .def(
2283           "verify",
2284           [](PyOperationBase &self) {
2285             return mlirOperationVerify(self.getOperation());
2286           },
2287           "Verify the operation and return true if it passes, false if it "
2288           "fails.")
2289       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2290            "Puts self immediately after the other operation in its parent "
2291            "block.")
2292       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2293            "Puts self immediately before the other operation in its parent "
2294            "block.")
2295       .def(
2296           "detach_from_parent",
2297           [](PyOperationBase &self) {
2298             PyOperation &operation = self.getOperation();
2299             operation.checkValid();
2300             if (!operation.isAttached())
2301               throw py::value_error("Detached operation has no parent.");
2302 
2303             operation.detachFromParent();
2304             return operation.createOpView();
2305           },
2306           "Detaches the operation from its parent block.");
2307 
2308   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2309       .def_static("create", &PyOperation::create, py::arg("name"),
2310                   py::arg("results") = py::none(),
2311                   py::arg("operands") = py::none(),
2312                   py::arg("attributes") = py::none(),
2313                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2314                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2315                   kOperationCreateDocstring)
2316       .def_property_readonly("parent",
2317                              [](PyOperation &self) -> py::object {
2318                                auto parent = self.getParentOperation();
2319                                if (parent)
2320                                  return parent->getObject();
2321                                return py::none();
2322                              })
2323       .def("erase", &PyOperation::erase)
2324       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2325                              &PyOperation::getCapsule)
2326       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2327       .def_property_readonly("name",
2328                              [](PyOperation &self) {
2329                                self.checkValid();
2330                                MlirOperation operation = self.get();
2331                                MlirStringRef name = mlirIdentifierStr(
2332                                    mlirOperationGetName(operation));
2333                                return py::str(name.data, name.length);
2334                              })
2335       .def_property_readonly(
2336           "context",
2337           [](PyOperation &self) {
2338             self.checkValid();
2339             return self.getContext().getObject();
2340           },
2341           "Context that owns the Operation")
2342       .def_property_readonly("opview", &PyOperation::createOpView);
2343 
2344   auto opViewClass =
2345       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2346           .def(py::init<py::object>(), py::arg("operation"))
2347           .def_property_readonly("operation", &PyOpView::getOperationObject)
2348           .def_property_readonly(
2349               "context",
2350               [](PyOpView &self) {
2351                 return self.getOperation().getContext().getObject();
2352               },
2353               "Context that owns the Operation")
2354           .def("__str__", [](PyOpView &self) {
2355             return py::str(self.getOperationObject());
2356           });
2357   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2358   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2359   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2360   opViewClass.attr("build_generic") = classmethod(
2361       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2362       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2363       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2364       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2365       "Builds a specific, generated OpView based on class level attributes.");
2366 
2367   //----------------------------------------------------------------------------
2368   // Mapping of PyRegion.
2369   //----------------------------------------------------------------------------
2370   py::class_<PyRegion>(m, "Region", py::module_local())
2371       .def_property_readonly(
2372           "blocks",
2373           [](PyRegion &self) {
2374             return PyBlockList(self.getParentOperation(), self.get());
2375           },
2376           "Returns a forward-optimized sequence of blocks.")
2377       .def_property_readonly(
2378           "owner",
2379           [](PyRegion &self) {
2380             return self.getParentOperation()->createOpView();
2381           },
2382           "Returns the operation owning this region.")
2383       .def(
2384           "__iter__",
2385           [](PyRegion &self) {
2386             self.checkValid();
2387             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2388             return PyBlockIterator(self.getParentOperation(), firstBlock);
2389           },
2390           "Iterates over blocks in the region.")
2391       .def("__eq__",
2392            [](PyRegion &self, PyRegion &other) {
2393              return self.get().ptr == other.get().ptr;
2394            })
2395       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2396 
2397   //----------------------------------------------------------------------------
2398   // Mapping of PyBlock.
2399   //----------------------------------------------------------------------------
2400   py::class_<PyBlock>(m, "Block", py::module_local())
2401       .def_property_readonly(
2402           "owner",
2403           [](PyBlock &self) {
2404             return self.getParentOperation()->createOpView();
2405           },
2406           "Returns the owning operation of this block.")
2407       .def_property_readonly(
2408           "region",
2409           [](PyBlock &self) {
2410             MlirRegion region = mlirBlockGetParentRegion(self.get());
2411             return PyRegion(self.getParentOperation(), region);
2412           },
2413           "Returns the owning region of this block.")
2414       .def_property_readonly(
2415           "arguments",
2416           [](PyBlock &self) {
2417             return PyBlockArgumentList(self.getParentOperation(), self.get());
2418           },
2419           "Returns a list of block arguments.")
2420       .def_property_readonly(
2421           "operations",
2422           [](PyBlock &self) {
2423             return PyOperationList(self.getParentOperation(), self.get());
2424           },
2425           "Returns a forward-optimized sequence of operations.")
2426       .def_static(
2427           "create_at_start",
2428           [](PyRegion &parent, py::list pyArgTypes) {
2429             parent.checkValid();
2430             llvm::SmallVector<MlirType, 4> argTypes;
2431             argTypes.reserve(pyArgTypes.size());
2432             for (auto &pyArg : pyArgTypes) {
2433               argTypes.push_back(pyArg.cast<PyType &>());
2434             }
2435 
2436             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2437             mlirRegionInsertOwnedBlock(parent, 0, block);
2438             return PyBlock(parent.getParentOperation(), block);
2439           },
2440           py::arg("parent"), py::arg("arg_types") = py::list(),
2441           "Creates and returns a new Block at the beginning of the given "
2442           "region (with given argument types).")
2443       .def(
2444           "create_before",
2445           [](PyBlock &self, py::args pyArgTypes) {
2446             self.checkValid();
2447             llvm::SmallVector<MlirType, 4> argTypes;
2448             argTypes.reserve(pyArgTypes.size());
2449             for (auto &pyArg : pyArgTypes) {
2450               argTypes.push_back(pyArg.cast<PyType &>());
2451             }
2452 
2453             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2454             MlirRegion region = mlirBlockGetParentRegion(self.get());
2455             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2456             return PyBlock(self.getParentOperation(), block);
2457           },
2458           "Creates and returns a new Block before this block "
2459           "(with given argument types).")
2460       .def(
2461           "create_after",
2462           [](PyBlock &self, py::args pyArgTypes) {
2463             self.checkValid();
2464             llvm::SmallVector<MlirType, 4> argTypes;
2465             argTypes.reserve(pyArgTypes.size());
2466             for (auto &pyArg : pyArgTypes) {
2467               argTypes.push_back(pyArg.cast<PyType &>());
2468             }
2469 
2470             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2471             MlirRegion region = mlirBlockGetParentRegion(self.get());
2472             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2473             return PyBlock(self.getParentOperation(), block);
2474           },
2475           "Creates and returns a new Block after this block "
2476           "(with given argument types).")
2477       .def(
2478           "__iter__",
2479           [](PyBlock &self) {
2480             self.checkValid();
2481             MlirOperation firstOperation =
2482                 mlirBlockGetFirstOperation(self.get());
2483             return PyOperationIterator(self.getParentOperation(),
2484                                        firstOperation);
2485           },
2486           "Iterates over operations in the block.")
2487       .def("__eq__",
2488            [](PyBlock &self, PyBlock &other) {
2489              return self.get().ptr == other.get().ptr;
2490            })
2491       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2492       .def(
2493           "__str__",
2494           [](PyBlock &self) {
2495             self.checkValid();
2496             PyPrintAccumulator printAccum;
2497             mlirBlockPrint(self.get(), printAccum.getCallback(),
2498                            printAccum.getUserData());
2499             return printAccum.join();
2500           },
2501           "Returns the assembly form of the block.")
2502       .def(
2503           "append",
2504           [](PyBlock &self, PyOperationBase &operation) {
2505             if (operation.getOperation().isAttached())
2506               operation.getOperation().detachFromParent();
2507 
2508             MlirOperation mlirOperation = operation.getOperation().get();
2509             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2510             operation.getOperation().setAttached(
2511                 self.getParentOperation().getObject());
2512           },
2513           py::arg("operation"),
2514           "Appends an operation to this block. If the operation is currently "
2515           "in another block, it will be moved.");
2516 
2517   //----------------------------------------------------------------------------
2518   // Mapping of PyInsertionPoint.
2519   //----------------------------------------------------------------------------
2520 
2521   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2522       .def(py::init<PyBlock &>(), py::arg("block"),
2523            "Inserts after the last operation but still inside the block.")
2524       .def("__enter__", &PyInsertionPoint::contextEnter)
2525       .def("__exit__", &PyInsertionPoint::contextExit)
2526       .def_property_readonly_static(
2527           "current",
2528           [](py::object & /*class*/) {
2529             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2530             if (!ip)
2531               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2532             return ip;
2533           },
2534           "Gets the InsertionPoint bound to the current thread or raises "
2535           "ValueError if none has been set")
2536       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2537            "Inserts before a referenced operation.")
2538       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2539                   py::arg("block"), "Inserts at the beginning of the block.")
2540       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2541                   py::arg("block"), "Inserts before the block terminator.")
2542       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2543            "Inserts an operation.")
2544       .def_property_readonly(
2545           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2546           "Returns the block that this InsertionPoint points to.");
2547 
2548   //----------------------------------------------------------------------------
2549   // Mapping of PyAttribute.
2550   //----------------------------------------------------------------------------
2551   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2552       // Delegate to the PyAttribute copy constructor, which will also lifetime
2553       // extend the backing context which owns the MlirAttribute.
2554       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2555            "Casts the passed attribute to the generic Attribute")
2556       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2557                              &PyAttribute::getCapsule)
2558       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2559       .def_static(
2560           "parse",
2561           [](std::string attrSpec, DefaultingPyMlirContext context) {
2562             MlirAttribute type = mlirAttributeParseGet(
2563                 context->get(), toMlirStringRef(attrSpec));
2564             // TODO: Rework error reporting once diagnostic engine is exposed
2565             // in C API.
2566             if (mlirAttributeIsNull(type)) {
2567               throw SetPyError(PyExc_ValueError,
2568                                Twine("Unable to parse attribute: '") +
2569                                    attrSpec + "'");
2570             }
2571             return PyAttribute(context->getRef(), type);
2572           },
2573           py::arg("asm"), py::arg("context") = py::none(),
2574           "Parses an attribute from an assembly form")
2575       .def_property_readonly(
2576           "context",
2577           [](PyAttribute &self) { return self.getContext().getObject(); },
2578           "Context that owns the Attribute")
2579       .def_property_readonly("type",
2580                              [](PyAttribute &self) {
2581                                return PyType(self.getContext()->getRef(),
2582                                              mlirAttributeGetType(self));
2583                              })
2584       .def(
2585           "get_named",
2586           [](PyAttribute &self, std::string name) {
2587             return PyNamedAttribute(self, std::move(name));
2588           },
2589           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2590       .def("__eq__",
2591            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2592       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2593       .def("__hash__",
2594            [](PyAttribute &self) {
2595              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2596            })
2597       .def(
2598           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2599           kDumpDocstring)
2600       .def(
2601           "__str__",
2602           [](PyAttribute &self) {
2603             PyPrintAccumulator printAccum;
2604             mlirAttributePrint(self, printAccum.getCallback(),
2605                                printAccum.getUserData());
2606             return printAccum.join();
2607           },
2608           "Returns the assembly form of the Attribute.")
2609       .def("__repr__", [](PyAttribute &self) {
2610         // Generally, assembly formats are not printed for __repr__ because
2611         // this can cause exceptionally long debug output and exceptions.
2612         // However, attribute values are generally considered useful and are
2613         // printed. This may need to be re-evaluated if debug dumps end up
2614         // being excessive.
2615         PyPrintAccumulator printAccum;
2616         printAccum.parts.append("Attribute(");
2617         mlirAttributePrint(self, printAccum.getCallback(),
2618                            printAccum.getUserData());
2619         printAccum.parts.append(")");
2620         return printAccum.join();
2621       });
2622 
2623   //----------------------------------------------------------------------------
2624   // Mapping of PyNamedAttribute
2625   //----------------------------------------------------------------------------
2626   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2627       .def("__repr__",
2628            [](PyNamedAttribute &self) {
2629              PyPrintAccumulator printAccum;
2630              printAccum.parts.append("NamedAttribute(");
2631              printAccum.parts.append(
2632                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
2633                          mlirIdentifierStr(self.namedAttr.name).length));
2634              printAccum.parts.append("=");
2635              mlirAttributePrint(self.namedAttr.attribute,
2636                                 printAccum.getCallback(),
2637                                 printAccum.getUserData());
2638              printAccum.parts.append(")");
2639              return printAccum.join();
2640            })
2641       .def_property_readonly(
2642           "name",
2643           [](PyNamedAttribute &self) {
2644             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2645                            mlirIdentifierStr(self.namedAttr.name).length);
2646           },
2647           "The name of the NamedAttribute binding")
2648       .def_property_readonly(
2649           "attr",
2650           [](PyNamedAttribute &self) {
2651             // TODO: When named attribute is removed/refactored, also remove
2652             // this constructor (it does an inefficient table lookup).
2653             auto contextRef = PyMlirContext::forContext(
2654                 mlirAttributeGetContext(self.namedAttr.attribute));
2655             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2656           },
2657           py::keep_alive<0, 1>(),
2658           "The underlying generic attribute of the NamedAttribute binding");
2659 
2660   //----------------------------------------------------------------------------
2661   // Mapping of PyType.
2662   //----------------------------------------------------------------------------
2663   py::class_<PyType>(m, "Type", py::module_local())
2664       // Delegate to the PyType copy constructor, which will also lifetime
2665       // extend the backing context which owns the MlirType.
2666       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2667            "Casts the passed type to the generic Type")
2668       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2669       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2670       .def_static(
2671           "parse",
2672           [](std::string typeSpec, DefaultingPyMlirContext context) {
2673             MlirType type =
2674                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2675             // TODO: Rework error reporting once diagnostic engine is exposed
2676             // in C API.
2677             if (mlirTypeIsNull(type)) {
2678               throw SetPyError(PyExc_ValueError,
2679                                Twine("Unable to parse type: '") + typeSpec +
2680                                    "'");
2681             }
2682             return PyType(context->getRef(), type);
2683           },
2684           py::arg("asm"), py::arg("context") = py::none(),
2685           kContextParseTypeDocstring)
2686       .def_property_readonly(
2687           "context", [](PyType &self) { return self.getContext().getObject(); },
2688           "Context that owns the Type")
2689       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2690       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2691       .def("__hash__",
2692            [](PyType &self) {
2693              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2694            })
2695       .def(
2696           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2697       .def(
2698           "__str__",
2699           [](PyType &self) {
2700             PyPrintAccumulator printAccum;
2701             mlirTypePrint(self, printAccum.getCallback(),
2702                           printAccum.getUserData());
2703             return printAccum.join();
2704           },
2705           "Returns the assembly form of the type.")
2706       .def("__repr__", [](PyType &self) {
2707         // Generally, assembly formats are not printed for __repr__ because
2708         // this can cause exceptionally long debug output and exceptions.
2709         // However, types are an exception as they typically have compact
2710         // assembly forms and printing them is useful.
2711         PyPrintAccumulator printAccum;
2712         printAccum.parts.append("Type(");
2713         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2714         printAccum.parts.append(")");
2715         return printAccum.join();
2716       });
2717 
2718   //----------------------------------------------------------------------------
2719   // Mapping of Value.
2720   //----------------------------------------------------------------------------
2721   py::class_<PyValue>(m, "Value", py::module_local())
2722       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2723       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2724       .def_property_readonly(
2725           "context",
2726           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2727           "Context in which the value lives.")
2728       .def(
2729           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2730           kDumpDocstring)
2731       .def_property_readonly(
2732           "owner",
2733           [](PyValue &self) {
2734             assert(mlirOperationEqual(self.getParentOperation()->get(),
2735                                       mlirOpResultGetOwner(self.get())) &&
2736                    "expected the owner of the value in Python to match that in "
2737                    "the IR");
2738             return self.getParentOperation().getObject();
2739           })
2740       .def("__eq__",
2741            [](PyValue &self, PyValue &other) {
2742              return self.get().ptr == other.get().ptr;
2743            })
2744       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2745       .def("__hash__",
2746            [](PyValue &self) {
2747              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2748            })
2749       .def(
2750           "__str__",
2751           [](PyValue &self) {
2752             PyPrintAccumulator printAccum;
2753             printAccum.parts.append("Value(");
2754             mlirValuePrint(self.get(), printAccum.getCallback(),
2755                            printAccum.getUserData());
2756             printAccum.parts.append(")");
2757             return printAccum.join();
2758           },
2759           kValueDunderStrDocstring)
2760       .def_property_readonly("type", [](PyValue &self) {
2761         return PyType(self.getParentOperation()->getContext(),
2762                       mlirValueGetType(self.get()));
2763       });
2764   PyBlockArgument::bind(m);
2765   PyOpResult::bind(m);
2766 
2767   //----------------------------------------------------------------------------
2768   // Mapping of SymbolTable.
2769   //----------------------------------------------------------------------------
2770   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
2771       .def(py::init<PyOperationBase &>())
2772       .def("__getitem__", &PySymbolTable::dunderGetItem)
2773       .def("insert", &PySymbolTable::insert, py::arg("operation"))
2774       .def("erase", &PySymbolTable::erase, py::arg("operation"))
2775       .def("__delitem__", &PySymbolTable::dunderDel)
2776       .def("__contains__", [](PySymbolTable &table, const std::string &name) {
2777         return !mlirOperationIsNull(mlirSymbolTableLookup(
2778             table, mlirStringRefCreate(name.data(), name.length())));
2779       });
2780 
  // Container bindings.
  // Each of the list, map and iterator helper types below is registered on
  // the module through its own bind(m) call.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  // Registered on the same module, again through bind(m).
  PyGlobalDebugFlag::bind(m);
2795 }
2796