1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include <pybind11/stl.h>
23 
24 namespace py = pybind11;
25 using namespace mlir;
26 using namespace mlir::python;
27 
28 using llvm::SmallVector;
29 using llvm::StringRef;
30 using llvm::Twine;
31 
32 //------------------------------------------------------------------------------
33 // Docstrings (trivial, non-duplicated docstrings are included inline).
34 //------------------------------------------------------------------------------
35 
// Docstring text for the Python bindings. These raw strings are presumably
// handed to pybind11 `.def(...)` calls elsewhere in this file as docstrings
// (kAppendBlockDocstring is used that way for BlockList.append). Their
// contents are user-visible runtime text, so editing them changes what
// Python users see.

// Docstring for parsing a type from its assembly form (binding not visible
// in this chunk).
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Docstrings for the Location factories: call-site, file/line/column and
// named locations.
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

// Docstring for parsing a module from its assembly form.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring for operation creation. NOTE(review): the argument names listed
// here are assumed to match the keyword arguments of the binding that uses
// it (outside this chunk); keep the two in sync.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
78 
// Docstring for the operation print method, whose C++ parameters are
// PyOperationBase::print(fileObject, binary, largeElementsLimit,
// enableDebugInfo, prettyDebugInfo, printGenericOpForm, useLocalScope).
// The keyword names are written in snake_case to match the Python API
// (NOTE(review): the py::arg names are bound outside this chunk; keep in
// sync).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";
97 
// Docstring for fetching an operation's assembly as a bytes/str object
// (see PyOperationBase::getAsm, which prints into an io.BytesIO/StringIO).
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Docstring for an operation's __str__ (default-option printing).
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared docstring for the `dump` methods (binding sites not visible here).
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring for BlockList.append (bound in PyBlockList::bind).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Docstring for a Value's __str__ (binding not visible in this chunk).
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
136 
137 //------------------------------------------------------------------------------
138 // Utilities.
139 //------------------------------------------------------------------------------
140 
141 /// Helper for creating an @classmethod.
142 template <class Func, typename... Args>
143 py::object classmethod(Func f, Args... args) {
144   py::object cf = py::cpp_function(f, args...);
145   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
146 }
147 
148 static py::object
149 createCustomDialectWrapper(const std::string &dialectNamespace,
150                            py::object dialectDescriptor) {
151   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
152   if (!dialectClass) {
153     // Use the base class.
154     return py::cast(PyDialect(std::move(dialectDescriptor)));
155   }
156 
157   // Create the custom implementation.
158   return (*dialectClass)(std::move(dialectDescriptor));
159 }
160 
161 static MlirStringRef toMlirStringRef(const std::string &s) {
162   return mlirStringRefCreate(s.data(), s.size());
163 }
164 
165 /// Wrapper for the global LLVM debugging flag.
166 struct PyGlobalDebugFlag {
167   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
168 
169   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
170 
171   static void bind(py::module &m) {
172     // Debug flags.
173     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
174         .def_property_static("flag", &PyGlobalDebugFlag::get,
175                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
176   }
177 };
178 
179 //------------------------------------------------------------------------------
180 // Collections.
181 //------------------------------------------------------------------------------
182 
183 namespace {
184 
185 class PyRegionIterator {
186 public:
187   PyRegionIterator(PyOperationRef operation)
188       : operation(std::move(operation)) {}
189 
190   PyRegionIterator &dunderIter() { return *this; }
191 
192   PyRegion dunderNext() {
193     operation->checkValid();
194     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
195       throw py::stop_iteration();
196     }
197     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
198     return PyRegion(operation, region);
199   }
200 
201   static void bind(py::module &m) {
202     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
203         .def("__iter__", &PyRegionIterator::dunderIter)
204         .def("__next__", &PyRegionIterator::dunderNext);
205   }
206 
207 private:
208   PyOperationRef operation;
209   int nextIndex = 0;
210 };
211 
212 /// Regions of an op are fixed length and indexed numerically so are represented
213 /// with a sequence-like container.
214 class PyRegionList {
215 public:
216   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
217 
218   intptr_t dunderLen() {
219     operation->checkValid();
220     return mlirOperationGetNumRegions(operation->get());
221   }
222 
223   PyRegion dunderGetItem(intptr_t index) {
224     // dunderLen checks validity.
225     if (index < 0 || index >= dunderLen()) {
226       throw SetPyError(PyExc_IndexError,
227                        "attempt to access out of bounds region");
228     }
229     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
230     return PyRegion(operation, region);
231   }
232 
233   static void bind(py::module &m) {
234     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
235         .def("__len__", &PyRegionList::dunderLen)
236         .def("__getitem__", &PyRegionList::dunderGetItem);
237   }
238 
239 private:
240   PyOperationRef operation;
241 };
242 
243 class PyBlockIterator {
244 public:
245   PyBlockIterator(PyOperationRef operation, MlirBlock next)
246       : operation(std::move(operation)), next(next) {}
247 
248   PyBlockIterator &dunderIter() { return *this; }
249 
250   PyBlock dunderNext() {
251     operation->checkValid();
252     if (mlirBlockIsNull(next)) {
253       throw py::stop_iteration();
254     }
255 
256     PyBlock returnBlock(operation, next);
257     next = mlirBlockGetNextInRegion(next);
258     return returnBlock;
259   }
260 
261   static void bind(py::module &m) {
262     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
263         .def("__iter__", &PyBlockIterator::dunderIter)
264         .def("__next__", &PyBlockIterator::dunderNext);
265   }
266 
267 private:
268   PyOperationRef operation;
269   MlirBlock next;
270 };
271 
272 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
273 /// we present them as a more full-featured list-like container but optimize
274 /// it for forward iteration. Blocks are always owned by a region.
275 class PyBlockList {
276 public:
277   PyBlockList(PyOperationRef operation, MlirRegion region)
278       : operation(std::move(operation)), region(region) {}
279 
280   PyBlockIterator dunderIter() {
281     operation->checkValid();
282     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
283   }
284 
285   intptr_t dunderLen() {
286     operation->checkValid();
287     intptr_t count = 0;
288     MlirBlock block = mlirRegionGetFirstBlock(region);
289     while (!mlirBlockIsNull(block)) {
290       count += 1;
291       block = mlirBlockGetNextInRegion(block);
292     }
293     return count;
294   }
295 
296   PyBlock dunderGetItem(intptr_t index) {
297     operation->checkValid();
298     if (index < 0) {
299       throw SetPyError(PyExc_IndexError,
300                        "attempt to access out of bounds block");
301     }
302     MlirBlock block = mlirRegionGetFirstBlock(region);
303     while (!mlirBlockIsNull(block)) {
304       if (index == 0) {
305         return PyBlock(operation, block);
306       }
307       block = mlirBlockGetNextInRegion(block);
308       index -= 1;
309     }
310     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
311   }
312 
313   PyBlock appendBlock(py::args pyArgTypes) {
314     operation->checkValid();
315     llvm::SmallVector<MlirType, 4> argTypes;
316     argTypes.reserve(pyArgTypes.size());
317     for (auto &pyArg : pyArgTypes) {
318       argTypes.push_back(pyArg.cast<PyType &>());
319     }
320 
321     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
322     mlirRegionAppendOwnedBlock(region, block);
323     return PyBlock(operation, block);
324   }
325 
326   static void bind(py::module &m) {
327     py::class_<PyBlockList>(m, "BlockList", py::module_local())
328         .def("__getitem__", &PyBlockList::dunderGetItem)
329         .def("__iter__", &PyBlockList::dunderIter)
330         .def("__len__", &PyBlockList::dunderLen)
331         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
332   }
333 
334 private:
335   PyOperationRef operation;
336   MlirRegion region;
337 };
338 
339 class PyOperationIterator {
340 public:
341   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
342       : parentOperation(std::move(parentOperation)), next(next) {}
343 
344   PyOperationIterator &dunderIter() { return *this; }
345 
346   py::object dunderNext() {
347     parentOperation->checkValid();
348     if (mlirOperationIsNull(next)) {
349       throw py::stop_iteration();
350     }
351 
352     PyOperationRef returnOperation =
353         PyOperation::forOperation(parentOperation->getContext(), next);
354     next = mlirOperationGetNextInBlock(next);
355     return returnOperation->createOpView();
356   }
357 
358   static void bind(py::module &m) {
359     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
360         .def("__iter__", &PyOperationIterator::dunderIter)
361         .def("__next__", &PyOperationIterator::dunderNext);
362   }
363 
364 private:
365   PyOperationRef parentOperation;
366   MlirOperation next;
367 };
368 
369 /// Operations are exposed by the C-API as a forward-only linked list. In
370 /// Python, we present them as a more full-featured list-like container but
371 /// optimize it for forward iteration. Iterable operations are always owned
372 /// by a block.
373 class PyOperationList {
374 public:
375   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
376       : parentOperation(std::move(parentOperation)), block(block) {}
377 
378   PyOperationIterator dunderIter() {
379     parentOperation->checkValid();
380     return PyOperationIterator(parentOperation,
381                                mlirBlockGetFirstOperation(block));
382   }
383 
384   intptr_t dunderLen() {
385     parentOperation->checkValid();
386     intptr_t count = 0;
387     MlirOperation childOp = mlirBlockGetFirstOperation(block);
388     while (!mlirOperationIsNull(childOp)) {
389       count += 1;
390       childOp = mlirOperationGetNextInBlock(childOp);
391     }
392     return count;
393   }
394 
395   py::object dunderGetItem(intptr_t index) {
396     parentOperation->checkValid();
397     if (index < 0) {
398       throw SetPyError(PyExc_IndexError,
399                        "attempt to access out of bounds operation");
400     }
401     MlirOperation childOp = mlirBlockGetFirstOperation(block);
402     while (!mlirOperationIsNull(childOp)) {
403       if (index == 0) {
404         return PyOperation::forOperation(parentOperation->getContext(), childOp)
405             ->createOpView();
406       }
407       childOp = mlirOperationGetNextInBlock(childOp);
408       index -= 1;
409     }
410     throw SetPyError(PyExc_IndexError,
411                      "attempt to access out of bounds operation");
412   }
413 
414   static void bind(py::module &m) {
415     py::class_<PyOperationList>(m, "OperationList", py::module_local())
416         .def("__getitem__", &PyOperationList::dunderGetItem)
417         .def("__iter__", &PyOperationList::dunderIter)
418         .def("__len__", &PyOperationList::dunderLen);
419   }
420 
421 private:
422   PyOperationRef parentOperation;
423   MlirBlock block;
424 };
425 
426 } // namespace
427 
428 //------------------------------------------------------------------------------
429 // PyMlirContext
430 //------------------------------------------------------------------------------
431 
432 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
433   py::gil_scoped_acquire acquire;
434   auto &liveContexts = getLiveContexts();
435   liveContexts[context.ptr] = this;
436 }
437 
438 PyMlirContext::~PyMlirContext() {
439   // Note that the only public way to construct an instance is via the
440   // forContext method, which always puts the associated handle into
441   // liveContexts.
442   py::gil_scoped_acquire acquire;
443   getLiveContexts().erase(context.ptr);
444   mlirContextDestroy(context);
445 }
446 
447 py::object PyMlirContext::getCapsule() {
448   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
449 }
450 
451 py::object PyMlirContext::createFromCapsule(py::object capsule) {
452   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
453   if (mlirContextIsNull(rawContext))
454     throw py::error_already_set();
455   return forContext(rawContext).releaseObject();
456 }
457 
458 PyMlirContext *PyMlirContext::createNewContextForInit() {
459   MlirContext context = mlirContextCreate();
460   mlirRegisterAllDialects(context);
461   return new PyMlirContext(context);
462 }
463 
464 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
465   py::gil_scoped_acquire acquire;
466   auto &liveContexts = getLiveContexts();
467   auto it = liveContexts.find(context.ptr);
468   if (it == liveContexts.end()) {
469     // Create.
470     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
471     py::object pyRef = py::cast(unownedContextWrapper);
472     assert(pyRef && "cast to py::object failed");
473     liveContexts[context.ptr] = unownedContextWrapper;
474     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
475   }
476   // Use existing.
477   py::object pyRef = py::cast(it->second);
478   return PyMlirContextRef(it->second, std::move(pyRef));
479 }
480 
481 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
482   static LiveContextMap liveContexts;
483   return liveContexts;
484 }
485 
486 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
487 
488 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
489 
490 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
491 
492 pybind11::object PyMlirContext::contextEnter() {
493   return PyThreadContextEntry::pushContext(*this);
494 }
495 
496 void PyMlirContext::contextExit(pybind11::object excType,
497                                 pybind11::object excVal,
498                                 pybind11::object excTb) {
499   PyThreadContextEntry::popContext(*this);
500 }
501 
502 PyMlirContext &DefaultingPyMlirContext::resolve() {
503   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
504   if (!context) {
505     throw SetPyError(
506         PyExc_RuntimeError,
507         "An MLIR function requires a Context but none was provided in the call "
508         "or from the surrounding environment. Either pass to the function with "
509         "a 'context=' argument or establish a default using 'with Context():'");
510   }
511   return *context;
512 }
513 
514 //------------------------------------------------------------------------------
515 // PyThreadContextEntry management
516 //------------------------------------------------------------------------------
517 
518 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
519   static thread_local std::vector<PyThreadContextEntry> stack;
520   return stack;
521 }
522 
523 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
524   auto &stack = getStack();
525   if (stack.empty())
526     return nullptr;
527   return &stack.back();
528 }
529 
530 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
531                                 py::object insertionPoint,
532                                 py::object location) {
533   auto &stack = getStack();
534   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
535                      std::move(location));
536   // If the new stack has more than one entry and the context of the new top
537   // entry matches the previous, copy the insertionPoint and location from the
538   // previous entry if missing from the new top entry.
539   if (stack.size() > 1) {
540     auto &prev = *(stack.rbegin() + 1);
541     auto &current = stack.back();
542     if (current.context.is(prev.context)) {
543       // Default non-context objects from the previous entry.
544       if (!current.insertionPoint)
545         current.insertionPoint = prev.insertionPoint;
546       if (!current.location)
547         current.location = prev.location;
548     }
549   }
550 }
551 
552 PyMlirContext *PyThreadContextEntry::getContext() {
553   if (!context)
554     return nullptr;
555   return py::cast<PyMlirContext *>(context);
556 }
557 
558 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
559   if (!insertionPoint)
560     return nullptr;
561   return py::cast<PyInsertionPoint *>(insertionPoint);
562 }
563 
564 PyLocation *PyThreadContextEntry::getLocation() {
565   if (!location)
566     return nullptr;
567   return py::cast<PyLocation *>(location);
568 }
569 
570 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
571   auto *tos = getTopOfStack();
572   return tos ? tos->getContext() : nullptr;
573 }
574 
575 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
576   auto *tos = getTopOfStack();
577   return tos ? tos->getInsertionPoint() : nullptr;
578 }
579 
580 PyLocation *PyThreadContextEntry::getDefaultLocation() {
581   auto *tos = getTopOfStack();
582   return tos ? tos->getLocation() : nullptr;
583 }
584 
585 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
586   py::object contextObj = py::cast(context);
587   push(FrameKind::Context, /*context=*/contextObj,
588        /*insertionPoint=*/py::object(),
589        /*location=*/py::object());
590   return contextObj;
591 }
592 
593 void PyThreadContextEntry::popContext(PyMlirContext &context) {
594   auto &stack = getStack();
595   if (stack.empty())
596     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
597   auto &tos = stack.back();
598   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
599     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
600   stack.pop_back();
601 }
602 
603 py::object
604 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
605   py::object contextObj =
606       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
607   py::object insertionPointObj = py::cast(insertionPoint);
608   push(FrameKind::InsertionPoint,
609        /*context=*/contextObj,
610        /*insertionPoint=*/insertionPointObj,
611        /*location=*/py::object());
612   return insertionPointObj;
613 }
614 
615 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
616   auto &stack = getStack();
617   if (stack.empty())
618     throw SetPyError(PyExc_RuntimeError,
619                      "Unbalanced InsertionPoint enter/exit");
620   auto &tos = stack.back();
621   if (tos.frameKind != FrameKind::InsertionPoint &&
622       tos.getInsertionPoint() != &insertionPoint)
623     throw SetPyError(PyExc_RuntimeError,
624                      "Unbalanced InsertionPoint enter/exit");
625   stack.pop_back();
626 }
627 
628 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
629   py::object contextObj = location.getContext().getObject();
630   py::object locationObj = py::cast(location);
631   push(FrameKind::Location, /*context=*/contextObj,
632        /*insertionPoint=*/py::object(),
633        /*location=*/locationObj);
634   return locationObj;
635 }
636 
637 void PyThreadContextEntry::popLocation(PyLocation &location) {
638   auto &stack = getStack();
639   if (stack.empty())
640     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
641   auto &tos = stack.back();
642   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
643     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
644   stack.pop_back();
645 }
646 
647 //------------------------------------------------------------------------------
648 // PyDialect, PyDialectDescriptor, PyDialects
649 //------------------------------------------------------------------------------
650 
651 MlirDialect PyDialects::getDialectForKey(const std::string &key,
652                                          bool attrError) {
653   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
654                                                     {key.data(), key.size()});
655   if (mlirDialectIsNull(dialect)) {
656     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
657                      Twine("Dialect '") + key + "' not found");
658   }
659   return dialect;
660 }
661 
662 //------------------------------------------------------------------------------
663 // PyLocation
664 //------------------------------------------------------------------------------
665 
666 py::object PyLocation::getCapsule() {
667   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
668 }
669 
670 PyLocation PyLocation::createFromCapsule(py::object capsule) {
671   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
672   if (mlirLocationIsNull(rawLoc))
673     throw py::error_already_set();
674   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
675                     rawLoc);
676 }
677 
678 py::object PyLocation::contextEnter() {
679   return PyThreadContextEntry::pushLocation(*this);
680 }
681 
682 void PyLocation::contextExit(py::object excType, py::object excVal,
683                              py::object excTb) {
684   PyThreadContextEntry::popLocation(*this);
685 }
686 
687 PyLocation &DefaultingPyLocation::resolve() {
688   auto *location = PyThreadContextEntry::getDefaultLocation();
689   if (!location) {
690     throw SetPyError(
691         PyExc_RuntimeError,
692         "An MLIR function requires a Location but none was provided in the "
693         "call or from the surrounding environment. Either pass to the function "
694         "with a 'loc=' argument or establish a default using 'with loc:'");
695   }
696   return *location;
697 }
698 
699 //------------------------------------------------------------------------------
700 // PyModule
701 //------------------------------------------------------------------------------
702 
703 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
704     : BaseContextObject(std::move(contextRef)), module(module) {}
705 
706 PyModule::~PyModule() {
707   py::gil_scoped_acquire acquire;
708   auto &liveModules = getContext()->liveModules;
709   assert(liveModules.count(module.ptr) == 1 &&
710          "destroying module not in live map");
711   liveModules.erase(module.ptr);
712   mlirModuleDestroy(module);
713 }
714 
715 PyModuleRef PyModule::forModule(MlirModule module) {
716   MlirContext context = mlirModuleGetContext(module);
717   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
718 
719   py::gil_scoped_acquire acquire;
720   auto &liveModules = contextRef->liveModules;
721   auto it = liveModules.find(module.ptr);
722   if (it == liveModules.end()) {
723     // Create.
724     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
725     // Note that the default return value policy on cast is automatic_reference,
726     // which does not take ownership (delete will not be called).
727     // Just be explicit.
728     py::object pyRef =
729         py::cast(unownedModule, py::return_value_policy::take_ownership);
730     unownedModule->handle = pyRef;
731     liveModules[module.ptr] =
732         std::make_pair(unownedModule->handle, unownedModule);
733     return PyModuleRef(unownedModule, std::move(pyRef));
734   }
735   // Use existing.
736   PyModule *existing = it->second.second;
737   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
738   return PyModuleRef(existing, std::move(pyRef));
739 }
740 
741 py::object PyModule::createFromCapsule(py::object capsule) {
742   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
743   if (mlirModuleIsNull(rawModule))
744     throw py::error_already_set();
745   return forModule(rawModule).releaseObject();
746 }
747 
748 py::object PyModule::getCapsule() {
749   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
750 }
751 
752 //------------------------------------------------------------------------------
753 // PyOperation
754 //------------------------------------------------------------------------------
755 
756 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
757     : BaseContextObject(std::move(contextRef)), operation(operation) {}
758 
759 PyOperation::~PyOperation() {
760   // If the operation has already been invalidated there is nothing to do.
761   if (!valid)
762     return;
763   auto &liveOperations = getContext()->liveOperations;
764   assert(liveOperations.count(operation.ptr) == 1 &&
765          "destroying operation not in live map");
766   liveOperations.erase(operation.ptr);
767   if (!isAttached()) {
768     mlirOperationDestroy(operation);
769   }
770 }
771 
772 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
773                                            MlirOperation operation,
774                                            py::object parentKeepAlive) {
775   auto &liveOperations = contextRef->liveOperations;
776   // Create.
777   PyOperation *unownedOperation =
778       new PyOperation(std::move(contextRef), operation);
779   // Note that the default return value policy on cast is automatic_reference,
780   // which does not take ownership (delete will not be called).
781   // Just be explicit.
782   py::object pyRef =
783       py::cast(unownedOperation, py::return_value_policy::take_ownership);
784   unownedOperation->handle = pyRef;
785   if (parentKeepAlive) {
786     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
787   }
788   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
789   return PyOperationRef(unownedOperation, std::move(pyRef));
790 }
791 
792 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
793                                          MlirOperation operation,
794                                          py::object parentKeepAlive) {
795   auto &liveOperations = contextRef->liveOperations;
796   auto it = liveOperations.find(operation.ptr);
797   if (it == liveOperations.end()) {
798     // Create.
799     return createInstance(std::move(contextRef), operation,
800                           std::move(parentKeepAlive));
801   }
802   // Use existing.
803   PyOperation *existing = it->second.second;
804   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
805   return PyOperationRef(existing, std::move(pyRef));
806 }
807 
808 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
809                                            MlirOperation operation,
810                                            py::object parentKeepAlive) {
811   auto &liveOperations = contextRef->liveOperations;
812   assert(liveOperations.count(operation.ptr) == 0 &&
813          "cannot create detached operation that already exists");
814   (void)liveOperations;
815 
816   PyOperationRef created = createInstance(std::move(contextRef), operation,
817                                           std::move(parentKeepAlive));
818   created->attached = false;
819   return created;
820 }
821 
822 void PyOperation::checkValid() const {
823   if (!valid) {
824     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
825   }
826 }
827 
828 void PyOperationBase::print(py::object fileObject, bool binary,
829                             llvm::Optional<int64_t> largeElementsLimit,
830                             bool enableDebugInfo, bool prettyDebugInfo,
831                             bool printGenericOpForm, bool useLocalScope) {
832   PyOperation &operation = getOperation();
833   operation.checkValid();
834   if (fileObject.is_none())
835     fileObject = py::module::import("sys").attr("stdout");
836 
837   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
838     fileObject.attr("write")("// Verification failed, printing generic form\n");
839     printGenericOpForm = true;
840   }
841 
842   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
843   if (largeElementsLimit)
844     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
845   if (enableDebugInfo)
846     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
847   if (printGenericOpForm)
848     mlirOpPrintingFlagsPrintGenericOpForm(flags);
849 
850   PyFileAccumulator accum(fileObject, binary);
851   py::gil_scoped_release();
852   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
853                               accum.getUserData());
854   mlirOpPrintingFlagsDestroy(flags);
855 }
856 
857 py::object PyOperationBase::getAsm(bool binary,
858                                    llvm::Optional<int64_t> largeElementsLimit,
859                                    bool enableDebugInfo, bool prettyDebugInfo,
860                                    bool printGenericOpForm,
861                                    bool useLocalScope) {
862   py::object fileObject;
863   if (binary) {
864     fileObject = py::module::import("io").attr("BytesIO")();
865   } else {
866     fileObject = py::module::import("io").attr("StringIO")();
867   }
868   print(fileObject, /*binary=*/binary,
869         /*largeElementsLimit=*/largeElementsLimit,
870         /*enableDebugInfo=*/enableDebugInfo,
871         /*prettyDebugInfo=*/prettyDebugInfo,
872         /*printGenericOpForm=*/printGenericOpForm,
873         /*useLocalScope=*/useLocalScope);
874 
875   return fileObject.attr("getvalue")();
876 }
877 
878 void PyOperationBase::moveAfter(PyOperationBase &other) {
879   PyOperation &operation = getOperation();
880   PyOperation &otherOp = other.getOperation();
881   operation.checkValid();
882   otherOp.checkValid();
883   mlirOperationMoveAfter(operation, otherOp);
884   operation.parentKeepAlive = otherOp.parentKeepAlive;
885 }
886 
887 void PyOperationBase::moveBefore(PyOperationBase &other) {
888   PyOperation &operation = getOperation();
889   PyOperation &otherOp = other.getOperation();
890   operation.checkValid();
891   otherOp.checkValid();
892   mlirOperationMoveBefore(operation, otherOp);
893   operation.parentKeepAlive = otherOp.parentKeepAlive;
894 }
895 
896 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
897   checkValid();
898   if (!isAttached())
899     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
900   MlirOperation operation = mlirOperationGetParentOperation(get());
901   if (mlirOperationIsNull(operation))
902     return {};
903   return PyOperation::forOperation(getContext(), operation);
904 }
905 
906 PyBlock PyOperation::getBlock() {
907   checkValid();
908   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
909   MlirBlock block = mlirOperationGetBlock(get());
910   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
911   assert(parentOperation && "Operation has no parent");
912   return PyBlock{std::move(*parentOperation), block};
913 }
914 
915 py::object PyOperation::getCapsule() {
916   checkValid();
917   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
918 }
919 
920 py::object PyOperation::createFromCapsule(py::object capsule) {
921   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
922   if (mlirOperationIsNull(rawOperation))
923     throw py::error_already_set();
924   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
925   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
926       .releaseObject();
927 }
928 
/// Creates a new operation from its generic pieces and returns its OpView.
///
/// All Python-level inputs are validated and unpacked into C-API values
/// before any MLIR state is built. Validation errors:
///   - ValueError: `regions` < 0, or a None operand / result type / successor.
///   - py::cast_error: an attribute key that is not a string, or an attribute
///     value that is not an Attribute (including None).
/// The operation is then created detached. Unless `maybeIp` is the literal
/// `False`, it is inserted at `maybeIp`, or, when `maybeIp` is None, at the
/// thread's default insertion point (if there is one).
/// Returns createOpView() for the new operation.
py::object PyOperation::create(
    std::string name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, py::object maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Owned (name, attribute) pairs; the strings must outlive the named
  // attributes built from them below.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        // (Caught before py::cast_error, which is presumably its base class.)
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Fresh empty regions; ownership is handed to the operation state.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);

  // InsertPoint active?
  // `maybeIp` is a tri-state: the literal False disables insertion, None
  // means "use the thread's default insertion point", anything else must be
  // an InsertionPoint (cast failure raises here, after creation; the op is
  // then still only held by `created`).
  if (!maybeIp.is(py::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
      ip = PyThreadContextEntry::getDefaultInsertionPoint();
    } else {
      ip = py::cast<PyInsertionPoint *>(maybeIp);
    }
    if (ip)
      ip->insert(*created.get());
  }

  return created->createOpView();
}
1061 
1062 py::object PyOperation::createOpView() {
1063   checkValid();
1064   MlirIdentifier ident = mlirOperationGetName(get());
1065   MlirStringRef identStr = mlirIdentifierStr(ident);
1066   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1067       StringRef(identStr.data, identStr.length));
1068   if (opViewClass)
1069     return (*opViewClass)(getRef().getObject());
1070   return py::cast(PyOpView(getRef().getObject()));
1071 }
1072 
1073 void PyOperation::erase() {
1074   checkValid();
1075   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1076   // Python reference to a child operation is live. All children should also
1077   // have their `valid` bit set to false.
1078   auto &liveOperations = getContext()->liveOperations;
1079   if (liveOperations.count(operation.ptr))
1080     liveOperations.erase(operation.ptr);
1081   mlirOperationDestroy(operation);
1082   valid = false;
1083 }
1084 
1085 //------------------------------------------------------------------------------
1086 // PyOpView
1087 //------------------------------------------------------------------------------
1088 
1089 py::object
1090 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1091                        py::list operandList,
1092                        llvm::Optional<py::dict> attributes,
1093                        llvm::Optional<std::vector<PyBlock *>> successors,
1094                        llvm::Optional<int> regions,
1095                        DefaultingPyLocation location, py::object maybeIp) {
1096   PyMlirContextRef context = location->getContext();
1097   // Class level operation construction metadata.
1098   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1099   // Operand and result segment specs are either none, which does no
1100   // variadic unpacking, or a list of ints with segment sizes, where each
1101   // element is either a positive number (typically 1 for a scalar) or -1 to
1102   // indicate that it is derived from the length of the same-indexed operand
1103   // or result (implying that it is a list at that position).
1104   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1105   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1106 
1107   std::vector<uint32_t> operandSegmentLengths;
1108   std::vector<uint32_t> resultSegmentLengths;
1109 
1110   // Validate/determine region count.
1111   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1112   int opMinRegionCount = std::get<0>(opRegionSpec);
1113   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1114   if (!regions) {
1115     regions = opMinRegionCount;
1116   }
1117   if (*regions < opMinRegionCount) {
1118     throw py::value_error(
1119         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1120          llvm::Twine(opMinRegionCount) +
1121          " regions but was built with regions=" + llvm::Twine(*regions))
1122             .str());
1123   }
1124   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1125     throw py::value_error(
1126         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1127          llvm::Twine(opMinRegionCount) +
1128          " regions but was built with regions=" + llvm::Twine(*regions))
1129             .str());
1130   }
1131 
1132   // Unpack results.
1133   std::vector<PyType *> resultTypes;
1134   resultTypes.reserve(resultTypeList.size());
1135   if (resultSegmentSpecObj.is_none()) {
1136     // Non-variadic result unpacking.
1137     for (auto it : llvm::enumerate(resultTypeList)) {
1138       try {
1139         resultTypes.push_back(py::cast<PyType *>(it.value()));
1140         if (!resultTypes.back())
1141           throw py::cast_error();
1142       } catch (py::cast_error &err) {
1143         throw py::value_error((llvm::Twine("Result ") +
1144                                llvm::Twine(it.index()) + " of operation \"" +
1145                                name + "\" must be a Type (" + err.what() + ")")
1146                                   .str());
1147       }
1148     }
1149   } else {
1150     // Sized result unpacking.
1151     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1152     if (resultSegmentSpec.size() != resultTypeList.size()) {
1153       throw py::value_error((llvm::Twine("Operation \"") + name +
1154                              "\" requires " +
1155                              llvm::Twine(resultSegmentSpec.size()) +
1156                              "result segments but was provided " +
1157                              llvm::Twine(resultTypeList.size()))
1158                                 .str());
1159     }
1160     resultSegmentLengths.reserve(resultTypeList.size());
1161     for (auto it :
1162          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1163       int segmentSpec = std::get<1>(it.value());
1164       if (segmentSpec == 1 || segmentSpec == 0) {
1165         // Unpack unary element.
1166         try {
1167           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1168           if (resultType) {
1169             resultTypes.push_back(resultType);
1170             resultSegmentLengths.push_back(1);
1171           } else if (segmentSpec == 0) {
1172             // Allowed to be optional.
1173             resultSegmentLengths.push_back(0);
1174           } else {
1175             throw py::cast_error("was None and result is not optional");
1176           }
1177         } catch (py::cast_error &err) {
1178           throw py::value_error((llvm::Twine("Result ") +
1179                                  llvm::Twine(it.index()) + " of operation \"" +
1180                                  name + "\" must be a Type (" + err.what() +
1181                                  ")")
1182                                     .str());
1183         }
1184       } else if (segmentSpec == -1) {
1185         // Unpack sequence by appending.
1186         try {
1187           if (std::get<0>(it.value()).is_none()) {
1188             // Treat it as an empty list.
1189             resultSegmentLengths.push_back(0);
1190           } else {
1191             // Unpack the list.
1192             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1193             for (py::object segmentItem : segment) {
1194               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1195               if (!resultTypes.back()) {
1196                 throw py::cast_error("contained a None item");
1197               }
1198             }
1199             resultSegmentLengths.push_back(segment.size());
1200           }
1201         } catch (std::exception &err) {
1202           // NOTE: Sloppy to be using a catch-all here, but there are at least
1203           // three different unrelated exceptions that can be thrown in the
1204           // above "casts". Just keep the scope above small and catch them all.
1205           throw py::value_error((llvm::Twine("Result ") +
1206                                  llvm::Twine(it.index()) + " of operation \"" +
1207                                  name + "\" must be a Sequence of Types (" +
1208                                  err.what() + ")")
1209                                     .str());
1210         }
1211       } else {
1212         throw py::value_error("Unexpected segment spec");
1213       }
1214     }
1215   }
1216 
1217   // Unpack operands.
1218   std::vector<PyValue *> operands;
1219   operands.reserve(operands.size());
1220   if (operandSegmentSpecObj.is_none()) {
1221     // Non-sized operand unpacking.
1222     for (auto it : llvm::enumerate(operandList)) {
1223       try {
1224         operands.push_back(py::cast<PyValue *>(it.value()));
1225         if (!operands.back())
1226           throw py::cast_error();
1227       } catch (py::cast_error &err) {
1228         throw py::value_error((llvm::Twine("Operand ") +
1229                                llvm::Twine(it.index()) + " of operation \"" +
1230                                name + "\" must be a Value (" + err.what() + ")")
1231                                   .str());
1232       }
1233     }
1234   } else {
1235     // Sized operand unpacking.
1236     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1237     if (operandSegmentSpec.size() != operandList.size()) {
1238       throw py::value_error((llvm::Twine("Operation \"") + name +
1239                              "\" requires " +
1240                              llvm::Twine(operandSegmentSpec.size()) +
1241                              "operand segments but was provided " +
1242                              llvm::Twine(operandList.size()))
1243                                 .str());
1244     }
1245     operandSegmentLengths.reserve(operandList.size());
1246     for (auto it :
1247          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1248       int segmentSpec = std::get<1>(it.value());
1249       if (segmentSpec == 1 || segmentSpec == 0) {
1250         // Unpack unary element.
1251         try {
1252           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1253           if (operandValue) {
1254             operands.push_back(operandValue);
1255             operandSegmentLengths.push_back(1);
1256           } else if (segmentSpec == 0) {
1257             // Allowed to be optional.
1258             operandSegmentLengths.push_back(0);
1259           } else {
1260             throw py::cast_error("was None and operand is not optional");
1261           }
1262         } catch (py::cast_error &err) {
1263           throw py::value_error((llvm::Twine("Operand ") +
1264                                  llvm::Twine(it.index()) + " of operation \"" +
1265                                  name + "\" must be a Value (" + err.what() +
1266                                  ")")
1267                                     .str());
1268         }
1269       } else if (segmentSpec == -1) {
1270         // Unpack sequence by appending.
1271         try {
1272           if (std::get<0>(it.value()).is_none()) {
1273             // Treat it as an empty list.
1274             operandSegmentLengths.push_back(0);
1275           } else {
1276             // Unpack the list.
1277             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1278             for (py::object segmentItem : segment) {
1279               operands.push_back(py::cast<PyValue *>(segmentItem));
1280               if (!operands.back()) {
1281                 throw py::cast_error("contained a None item");
1282               }
1283             }
1284             operandSegmentLengths.push_back(segment.size());
1285           }
1286         } catch (std::exception &err) {
1287           // NOTE: Sloppy to be using a catch-all here, but there are at least
1288           // three different unrelated exceptions that can be thrown in the
1289           // above "casts". Just keep the scope above small and catch them all.
1290           throw py::value_error((llvm::Twine("Operand ") +
1291                                  llvm::Twine(it.index()) + " of operation \"" +
1292                                  name + "\" must be a Sequence of Values (" +
1293                                  err.what() + ")")
1294                                     .str());
1295         }
1296       } else {
1297         throw py::value_error("Unexpected segment spec");
1298       }
1299     }
1300   }
1301 
1302   // Merge operand/result segment lengths into attributes if needed.
1303   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1304     // Dup.
1305     if (attributes) {
1306       attributes = py::dict(*attributes);
1307     } else {
1308       attributes = py::dict();
1309     }
1310     if (attributes->contains("result_segment_sizes") ||
1311         attributes->contains("operand_segment_sizes")) {
1312       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1313                             "'operand_segment_sizes' attribute is unsupported. "
1314                             "Use Operation.create for such low-level access.");
1315     }
1316 
1317     // Add result_segment_sizes attribute.
1318     if (!resultSegmentLengths.empty()) {
1319       int64_t size = resultSegmentLengths.size();
1320       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1321           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1322           resultSegmentLengths.size(), resultSegmentLengths.data());
1323       (*attributes)["result_segment_sizes"] =
1324           PyAttribute(context, segmentLengthAttr);
1325     }
1326 
1327     // Add operand_segment_sizes attribute.
1328     if (!operandSegmentLengths.empty()) {
1329       int64_t size = operandSegmentLengths.size();
1330       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1331           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1332           operandSegmentLengths.size(), operandSegmentLengths.data());
1333       (*attributes)["operand_segment_sizes"] =
1334           PyAttribute(context, segmentLengthAttr);
1335     }
1336   }
1337 
1338   // Delegate to create.
1339   return PyOperation::create(std::move(name),
1340                              /*results=*/std::move(resultTypes),
1341                              /*operands=*/std::move(operands),
1342                              /*attributes=*/std::move(attributes),
1343                              /*successors=*/std::move(successors),
1344                              /*regions=*/*regions, location, maybeIp);
1345 }
1346 
/// Wraps the operation behind `operationObject`, which may be any
/// PyOperationBase subclass (an Operation or another OpView). The
/// `operationObject` member keeps a reference to the underlying operation's
/// own Python object.
/// NOTE(review): the second initializer reads `operation`, so this relies on
/// `operation` being declared before `operationObject` in the class —
/// confirm against the header before reordering members.
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1352 
1353 py::object PyOpView::createRawSubclass(py::object userClass) {
1354   // This is... a little gross. The typical pattern is to have a pure python
1355   // class that extends OpView like:
1356   //   class AddFOp(_cext.ir.OpView):
1357   //     def __init__(self, loc, lhs, rhs):
1358   //       operation = loc.context.create_operation(
1359   //           "addf", lhs, rhs, results=[lhs.type])
1360   //       super().__init__(operation)
1361   //
1362   // I.e. The goal of the user facing type is to provide a nice constructor
1363   // that has complete freedom for the op under construction. This is at odds
1364   // with our other desire to sometimes create this object by just passing an
1365   // operation (to initialize the base class). We could do *arg and **kwargs
1366   // munging to try to make it work, but instead, we synthesize a new class
1367   // on the fly which extends this user class (AddFOp in this example) and
1368   // *give it* the base class's __init__ method, thus bypassing the
1369   // intermediate subclass's __init__ method entirely. While slightly,
1370   // underhanded, this is safe/legal because the type hierarchy has not changed
1371   // (we just added a new leaf) and we aren't mucking around with __new__.
1372   // Typically, this new class will be stored on the original as "_Raw" and will
1373   // be used for casts and other things that need a variant of the class that
1374   // is initialized purely from an operation.
1375   py::object parentMetaclass =
1376       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1377   py::dict attributes;
1378   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1379   // now.
1380   //   auto opViewType = py::type::of<PyOpView>();
1381   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1382   attributes["__init__"] = opViewType.attr("__init__");
1383   py::str origName = userClass.attr("__name__");
1384   py::str newName = py::str("_") + origName;
1385   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1386 }
1387 
1388 //------------------------------------------------------------------------------
1389 // PyInsertionPoint.
1390 //------------------------------------------------------------------------------
1391 
/// Creates an insertion point at the end of `block`: no reference operation
/// is set, so insert() inserts before "null", i.e. appends.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1393 
/// Creates an insertion point immediately before `beforeOperationBase`, in
/// that operation's block. getBlock() raises ValueError if the operation is
/// detached.
/// NOTE(review): `block` is initialized from `refOperation`, so this relies
/// on `refOperation` being declared before `block` in the class — confirm
/// against the header before reordering members.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1397 
1398 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1399   PyOperation &operation = operationBase.getOperation();
1400   if (operation.isAttached())
1401     throw SetPyError(PyExc_ValueError,
1402                      "Attempt to insert operation that is already attached");
1403   block.getParentOperation()->checkValid();
1404   MlirOperation beforeOp = {nullptr};
1405   if (refOperation) {
1406     // Insert before operation.
1407     (*refOperation)->checkValid();
1408     beforeOp = (*refOperation)->get();
1409   } else {
1410     // Insert at end (before null) is only valid if the block does not
1411     // already end in a known terminator (violating this will cause assertion
1412     // failures later).
1413     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1414       throw py::index_error("Cannot insert operation at the end of a block "
1415                             "that already has a terminator. Did you mean to "
1416                             "use 'InsertionPoint.at_block_terminator(block)' "
1417                             "versus 'InsertionPoint(block)'?");
1418     }
1419   }
1420   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1421   operation.setAttached();
1422 }
1423 
1424 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1425   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1426   if (mlirOperationIsNull(firstOp)) {
1427     // Just insert at end.
1428     return PyInsertionPoint(block);
1429   }
1430 
1431   // Insert before first op.
1432   PyOperationRef firstOpRef = PyOperation::forOperation(
1433       block.getParentOperation()->getContext(), firstOp);
1434   return PyInsertionPoint{block, std::move(firstOpRef)};
1435 }
1436 
1437 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1438   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1439   if (mlirOperationIsNull(terminator))
1440     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1441   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1442       block.getParentOperation()->getContext(), terminator);
1443   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1444 }
1445 
1446 py::object PyInsertionPoint::contextEnter() {
1447   return PyThreadContextEntry::pushInsertionPoint(*this);
1448 }
1449 
1450 void PyInsertionPoint::contextExit(pybind11::object excType,
1451                                    pybind11::object excVal,
1452                                    pybind11::object excTb) {
1453   PyThreadContextEntry::popInsertionPoint(*this);
1454 }
1455 
1456 //------------------------------------------------------------------------------
1457 // PyAttribute.
1458 //------------------------------------------------------------------------------
1459 
1460 bool PyAttribute::operator==(const PyAttribute &other) {
1461   return mlirAttributeEqual(attr, other.attr);
1462 }
1463 
1464 py::object PyAttribute::getCapsule() {
1465   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1466 }
1467 
1468 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1469   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1470   if (mlirAttributeIsNull(rawAttr))
1471     throw py::error_already_set();
1472   return PyAttribute(
1473       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1474 }
1475 
1476 //------------------------------------------------------------------------------
1477 // PyNamedAttribute.
1478 //------------------------------------------------------------------------------
1479 
1480 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1481     : ownedName(new std::string(std::move(ownedName))) {
1482   namedAttr = mlirNamedAttributeGet(
1483       mlirIdentifierGet(mlirAttributeGetContext(attr),
1484                         toMlirStringRef(*this->ownedName)),
1485       attr);
1486 }
1487 
1488 //------------------------------------------------------------------------------
1489 // PyType.
1490 //------------------------------------------------------------------------------
1491 
1492 bool PyType::operator==(const PyType &other) {
1493   return mlirTypeEqual(type, other.type);
1494 }
1495 
1496 py::object PyType::getCapsule() {
1497   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1498 }
1499 
1500 PyType PyType::createFromCapsule(py::object capsule) {
1501   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1502   if (mlirTypeIsNull(rawType))
1503     throw py::error_already_set();
1504   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1505                 rawType);
1506 }
1507 
1508 //------------------------------------------------------------------------------
1509 // PyValue and subclases.
1510 //------------------------------------------------------------------------------
1511 
1512 pybind11::object PyValue::getCapsule() {
1513   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1514 }
1515 
1516 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1517   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1518   if (mlirValueIsNull(value))
1519     throw py::error_already_set();
1520   MlirOperation owner;
1521   if (mlirValueIsAOpResult(value))
1522     owner = mlirOpResultGetOwner(value);
1523   if (mlirValueIsABlockArgument(value))
1524     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1525   if (mlirOperationIsNull(owner))
1526     throw py::error_already_set();
1527   MlirContext ctx = mlirOperationGetContext(owner);
1528   PyOperationRef ownerRef =
1529       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1530   return PyValue(ownerRef, value);
1531 }
1532 
1533 namespace {
1534 /// CRTP base class for Python MLIR values that subclass Value and should be
1535 /// castable from it. The value hierarchy is one level deep and is not supposed
1536 /// to accommodate other levels unless core MLIR changes.
1537 template <typename DerivedTy>
1538 class PyConcreteValue : public PyValue {
1539 public:
1540   // Derived classes must define statics for:
1541   //   IsAFunctionTy isaFunction
1542   //   const char *pyClassName
1543   // and redefine bindDerived.
1544   using ClassTy = py::class_<DerivedTy, PyValue>;
1545   using IsAFunctionTy = bool (*)(MlirValue);
1546 
1547   PyConcreteValue() = default;
1548   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1549       : PyValue(operationRef, value) {}
1550   PyConcreteValue(PyValue &orig)
1551       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1552 
1553   /// Attempts to cast the original value to the derived type and throws on
1554   /// type mismatches.
1555   static MlirValue castFrom(PyValue &orig) {
1556     if (!DerivedTy::isaFunction(orig.get())) {
1557       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1558       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1559                                              DerivedTy::pyClassName +
1560                                              " (from " + origRepr + ")");
1561     }
1562     return orig.get();
1563   }
1564 
1565   /// Binds the Python module objects to functions of this class.
1566   static void bind(py::module &m) {
1567     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1568     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1569     cls.def_static("isinstance", [](PyValue &otherValue) -> bool {
1570       return DerivedTy::isaFunction(otherValue);
1571     });
1572     DerivedTy::bindDerived(cls);
1573   }
1574 
1575   /// Implemented by derived classes to add methods to the Python subclass.
1576   static void bindDerived(ClassTy &m) {}
1577 };
1578 
1579 /// Python wrapper for MlirBlockArgument.
1580 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1581 public:
1582   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1583   static constexpr const char *pyClassName = "BlockArgument";
1584   using PyConcreteValue::PyConcreteValue;
1585 
1586   static void bindDerived(ClassTy &c) {
1587     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1588       return PyBlock(self.getParentOperation(),
1589                      mlirBlockArgumentGetOwner(self.get()));
1590     });
1591     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1592       return mlirBlockArgumentGetArgNumber(self.get());
1593     });
1594     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1595       return mlirBlockArgumentSetType(self.get(), type);
1596     });
1597   }
1598 };
1599 
1600 /// Python wrapper for MlirOpResult.
1601 class PyOpResult : public PyConcreteValue<PyOpResult> {
1602 public:
1603   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1604   static constexpr const char *pyClassName = "OpResult";
1605   using PyConcreteValue::PyConcreteValue;
1606 
1607   static void bindDerived(ClassTy &c) {
1608     c.def_property_readonly("owner", [](PyOpResult &self) {
1609       assert(
1610           mlirOperationEqual(self.getParentOperation()->get(),
1611                              mlirOpResultGetOwner(self.get())) &&
1612           "expected the owner of the value in Python to match that in the IR");
1613       return self.getParentOperation().getObject();
1614     });
1615     c.def_property_readonly("result_number", [](PyOpResult &self) {
1616       return mlirOpResultGetResultNumber(self.get());
1617     });
1618   }
1619 };
1620 
1621 /// Returns the list of types of the values held by container.
1622 template <typename Container>
1623 static std::vector<PyType> getValueTypes(Container &container,
1624                                          PyMlirContextRef &context) {
1625   std::vector<PyType> result;
1626   result.reserve(container.getNumElements());
1627   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1628     result.push_back(
1629         PyType(context, mlirValueGetType(container.getElement(i).get())));
1630   }
1631   return result;
1632 }
1633 
1634 /// A list of block arguments. Internally, these are stored as consecutive
1635 /// elements, random access is cheap. The argument list is associated with the
1636 /// operation that contains the block (detached blocks are not allowed in
1637 /// Python bindings) and extends its lifetime.
1638 class PyBlockArgumentList
1639     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1640 public:
1641   static constexpr const char *pyClassName = "BlockArgumentList";
1642 
1643   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1644                       intptr_t startIndex = 0, intptr_t length = -1,
1645                       intptr_t step = 1)
1646       : Sliceable(startIndex,
1647                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1648                   step),
1649         operation(std::move(operation)), block(block) {}
1650 
1651   /// Returns the number of arguments in the list.
1652   intptr_t getNumElements() {
1653     operation->checkValid();
1654     return mlirBlockGetNumArguments(block);
1655   }
1656 
1657   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1658   PyBlockArgument getElement(intptr_t pos) {
1659     MlirValue argument = mlirBlockGetArgument(block, pos);
1660     return PyBlockArgument(operation, argument);
1661   }
1662 
1663   /// Returns a sublist of this list.
1664   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1665                             intptr_t step) {
1666     return PyBlockArgumentList(operation, block, startIndex, length, step);
1667   }
1668 
1669   static void bindDerived(ClassTy &c) {
1670     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1671       return getValueTypes(self, self.operation->getContext());
1672     });
1673   }
1674 
1675 private:
1676   PyOperationRef operation;
1677   MlirBlock block;
1678 };
1679 
1680 /// A list of operation operands. Internally, these are stored as consecutive
1681 /// elements, random access is cheap. The result list is associated with the
1682 /// operation whose results these are, and extends the lifetime of this
1683 /// operation.
1684 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1685 public:
1686   static constexpr const char *pyClassName = "OpOperandList";
1687 
1688   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1689                   intptr_t length = -1, intptr_t step = 1)
1690       : Sliceable(startIndex,
1691                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1692                                : length,
1693                   step),
1694         operation(operation) {}
1695 
1696   intptr_t getNumElements() {
1697     operation->checkValid();
1698     return mlirOperationGetNumOperands(operation->get());
1699   }
1700 
1701   PyValue getElement(intptr_t pos) {
1702     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1703     MlirOperation owner;
1704     if (mlirValueIsAOpResult(operand))
1705       owner = mlirOpResultGetOwner(operand);
1706     else if (mlirValueIsABlockArgument(operand))
1707       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1708     else
1709       assert(false && "Value must be an block arg or op result.");
1710     PyOperationRef pyOwner =
1711         PyOperation::forOperation(operation->getContext(), owner);
1712     return PyValue(pyOwner, operand);
1713   }
1714 
1715   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1716     return PyOpOperandList(operation, startIndex, length, step);
1717   }
1718 
1719   void dunderSetItem(intptr_t index, PyValue value) {
1720     index = wrapIndex(index);
1721     mlirOperationSetOperand(operation->get(), index, value.get());
1722   }
1723 
1724   static void bindDerived(ClassTy &c) {
1725     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1726   }
1727 
1728 private:
1729   PyOperationRef operation;
1730 };
1731 
1732 /// A list of operation results. Internally, these are stored as consecutive
1733 /// elements, random access is cheap. The result list is associated with the
1734 /// operation whose results these are, and extends the lifetime of this
1735 /// operation.
1736 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1737 public:
1738   static constexpr const char *pyClassName = "OpResultList";
1739 
1740   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1741                  intptr_t length = -1, intptr_t step = 1)
1742       : Sliceable(startIndex,
1743                   length == -1 ? mlirOperationGetNumResults(operation->get())
1744                                : length,
1745                   step),
1746         operation(operation) {}
1747 
1748   intptr_t getNumElements() {
1749     operation->checkValid();
1750     return mlirOperationGetNumResults(operation->get());
1751   }
1752 
1753   PyOpResult getElement(intptr_t index) {
1754     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1755     return PyOpResult(value);
1756   }
1757 
1758   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1759     return PyOpResultList(operation, startIndex, length, step);
1760   }
1761 
1762   static void bindDerived(ClassTy &c) {
1763     c.def_property_readonly("types", [](PyOpResultList &self) {
1764       return getValueTypes(self, self.operation->getContext());
1765     });
1766   }
1767 
1768 private:
1769   PyOperationRef operation;
1770 };
1771 
1772 /// A list of operation attributes. Can be indexed by name, producing
1773 /// attributes, or by index, producing named attributes.
1774 class PyOpAttributeMap {
1775 public:
1776   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1777 
1778   PyAttribute dunderGetItemNamed(const std::string &name) {
1779     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1780                                                          toMlirStringRef(name));
1781     if (mlirAttributeIsNull(attr)) {
1782       throw SetPyError(PyExc_KeyError,
1783                        "attempt to access a non-existent attribute");
1784     }
1785     return PyAttribute(operation->getContext(), attr);
1786   }
1787 
1788   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1789     if (index < 0 || index >= dunderLen()) {
1790       throw SetPyError(PyExc_IndexError,
1791                        "attempt to access out of bounds attribute");
1792     }
1793     MlirNamedAttribute namedAttr =
1794         mlirOperationGetAttribute(operation->get(), index);
1795     return PyNamedAttribute(
1796         namedAttr.attribute,
1797         std::string(mlirIdentifierStr(namedAttr.name).data));
1798   }
1799 
1800   void dunderSetItem(const std::string &name, PyAttribute attr) {
1801     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1802                                     attr);
1803   }
1804 
1805   void dunderDelItem(const std::string &name) {
1806     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1807                                                      toMlirStringRef(name));
1808     if (!removed)
1809       throw SetPyError(PyExc_KeyError,
1810                        "attempt to delete a non-existent attribute");
1811   }
1812 
1813   intptr_t dunderLen() {
1814     return mlirOperationGetNumAttributes(operation->get());
1815   }
1816 
1817   bool dunderContains(const std::string &name) {
1818     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1819         operation->get(), toMlirStringRef(name)));
1820   }
1821 
1822   static void bind(py::module &m) {
1823     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
1824         .def("__contains__", &PyOpAttributeMap::dunderContains)
1825         .def("__len__", &PyOpAttributeMap::dunderLen)
1826         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1827         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1828         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1829         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1830   }
1831 
1832 private:
1833   PyOperationRef operation;
1834 };
1835 
1836 } // end namespace
1837 
1838 //------------------------------------------------------------------------------
1839 // Populates the core exports of the 'ir' submodule.
1840 //------------------------------------------------------------------------------
1841 
1842 void mlir::python::populateIRCore(py::module &m) {
1843   //----------------------------------------------------------------------------
1844   // Mapping of MlirContext.
1845   //----------------------------------------------------------------------------
1846   py::class_<PyMlirContext>(m, "Context", py::module_local())
1847       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1848       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1849       .def("_get_context_again",
1850            [](PyMlirContext &self) {
1851              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1852              return ref.releaseObject();
1853            })
1854       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1855       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1856       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1857                              &PyMlirContext::getCapsule)
1858       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1859       .def("__enter__", &PyMlirContext::contextEnter)
1860       .def("__exit__", &PyMlirContext::contextExit)
1861       .def_property_readonly_static(
1862           "current",
1863           [](py::object & /*class*/) {
1864             auto *context = PyThreadContextEntry::getDefaultContext();
1865             if (!context)
1866               throw SetPyError(PyExc_ValueError, "No current Context");
1867             return context;
1868           },
1869           "Gets the Context bound to the current thread or raises ValueError")
1870       .def_property_readonly(
1871           "dialects",
1872           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1873           "Gets a container for accessing dialects by name")
1874       .def_property_readonly(
1875           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1876           "Alias for 'dialect'")
1877       .def(
1878           "get_dialect_descriptor",
1879           [=](PyMlirContext &self, std::string &name) {
1880             MlirDialect dialect = mlirContextGetOrLoadDialect(
1881                 self.get(), {name.data(), name.size()});
1882             if (mlirDialectIsNull(dialect)) {
1883               throw SetPyError(PyExc_ValueError,
1884                                Twine("Dialect '") + name + "' not found");
1885             }
1886             return PyDialectDescriptor(self.getRef(), dialect);
1887           },
1888           "Gets or loads a dialect by name, returning its descriptor object")
1889       .def_property(
1890           "allow_unregistered_dialects",
1891           [](PyMlirContext &self) -> bool {
1892             return mlirContextGetAllowUnregisteredDialects(self.get());
1893           },
1894           [](PyMlirContext &self, bool value) {
1895             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1896           })
1897       .def("enable_multithreading",
1898            [](PyMlirContext &self, bool enable) {
1899              mlirContextEnableMultithreading(self.get(), enable);
1900            })
1901       .def("is_registered_operation",
1902            [](PyMlirContext &self, std::string &name) {
1903              return mlirContextIsRegisteredOperation(
1904                  self.get(), MlirStringRef{name.data(), name.size()});
1905            });
1906 
1907   //----------------------------------------------------------------------------
1908   // Mapping of PyDialectDescriptor
1909   //----------------------------------------------------------------------------
1910   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
1911       .def_property_readonly("namespace",
1912                              [](PyDialectDescriptor &self) {
1913                                MlirStringRef ns =
1914                                    mlirDialectGetNamespace(self.get());
1915                                return py::str(ns.data, ns.length);
1916                              })
1917       .def("__repr__", [](PyDialectDescriptor &self) {
1918         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1919         std::string repr("<DialectDescriptor ");
1920         repr.append(ns.data, ns.length);
1921         repr.append(">");
1922         return repr;
1923       });
1924 
1925   //----------------------------------------------------------------------------
1926   // Mapping of PyDialects
1927   //----------------------------------------------------------------------------
1928   py::class_<PyDialects>(m, "Dialects", py::module_local())
1929       .def("__getitem__",
1930            [=](PyDialects &self, std::string keyName) {
1931              MlirDialect dialect =
1932                  self.getDialectForKey(keyName, /*attrError=*/false);
1933              py::object descriptor =
1934                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1935              return createCustomDialectWrapper(keyName, std::move(descriptor));
1936            })
1937       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1938         MlirDialect dialect =
1939             self.getDialectForKey(attrName, /*attrError=*/true);
1940         py::object descriptor =
1941             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1942         return createCustomDialectWrapper(attrName, std::move(descriptor));
1943       });
1944 
1945   //----------------------------------------------------------------------------
1946   // Mapping of PyDialect
1947   //----------------------------------------------------------------------------
1948   py::class_<PyDialect>(m, "Dialect", py::module_local())
1949       .def(py::init<py::object>(), "descriptor")
1950       .def_property_readonly(
1951           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1952       .def("__repr__", [](py::object self) {
1953         auto clazz = self.attr("__class__");
1954         return py::str("<Dialect ") +
1955                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1956                clazz.attr("__module__") + py::str(".") +
1957                clazz.attr("__name__") + py::str(")>");
1958       });
1959 
1960   //----------------------------------------------------------------------------
1961   // Mapping of Location
1962   //----------------------------------------------------------------------------
1963   py::class_<PyLocation>(m, "Location", py::module_local())
1964       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
1965       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
1966       .def("__enter__", &PyLocation::contextEnter)
1967       .def("__exit__", &PyLocation::contextExit)
1968       .def("__eq__",
1969            [](PyLocation &self, PyLocation &other) -> bool {
1970              return mlirLocationEqual(self, other);
1971            })
1972       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
1973       .def_property_readonly_static(
1974           "current",
1975           [](py::object & /*class*/) {
1976             auto *loc = PyThreadContextEntry::getDefaultLocation();
1977             if (!loc)
1978               throw SetPyError(PyExc_ValueError, "No current Location");
1979             return loc;
1980           },
1981           "Gets the Location bound to the current thread or raises ValueError")
1982       .def_static(
1983           "unknown",
1984           [](DefaultingPyMlirContext context) {
1985             return PyLocation(context->getRef(),
1986                               mlirLocationUnknownGet(context->get()));
1987           },
1988           py::arg("context") = py::none(),
1989           "Gets a Location representing an unknown location")
1990       .def_static(
1991           "callsite",
1992           [](PyLocation callee, const std::vector<PyLocation> &frames,
1993              DefaultingPyMlirContext context) {
1994             if (frames.empty())
1995               throw py::value_error("No caller frames provided");
1996             MlirLocation caller = frames.back().get();
1997             for (const PyLocation &frame :
1998                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
1999               caller = mlirLocationCallSiteGet(frame.get(), caller);
2000             return PyLocation(context->getRef(),
2001                               mlirLocationCallSiteGet(callee.get(), caller));
2002           },
2003           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2004           kContextGetCallSiteLocationDocstring)
2005       .def_static(
2006           "file",
2007           [](std::string filename, int line, int col,
2008              DefaultingPyMlirContext context) {
2009             return PyLocation(
2010                 context->getRef(),
2011                 mlirLocationFileLineColGet(
2012                     context->get(), toMlirStringRef(filename), line, col));
2013           },
2014           py::arg("filename"), py::arg("line"), py::arg("col"),
2015           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2016       .def_static(
2017           "name",
2018           [](std::string name, llvm::Optional<PyLocation> childLoc,
2019              DefaultingPyMlirContext context) {
2020             return PyLocation(
2021                 context->getRef(),
2022                 mlirLocationNameGet(
2023                     context->get(), toMlirStringRef(name),
2024                     childLoc ? childLoc->get()
2025                              : mlirLocationUnknownGet(context->get())));
2026           },
2027           py::arg("name"), py::arg("childLoc") = py::none(),
2028           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2029       .def_property_readonly(
2030           "context",
2031           [](PyLocation &self) { return self.getContext().getObject(); },
2032           "Context that owns the Location")
2033       .def("__repr__", [](PyLocation &self) {
2034         PyPrintAccumulator printAccum;
2035         mlirLocationPrint(self, printAccum.getCallback(),
2036                           printAccum.getUserData());
2037         return printAccum.join();
2038       });
2039 
2040   //----------------------------------------------------------------------------
2041   // Mapping of Module
2042   //----------------------------------------------------------------------------
2043   py::class_<PyModule>(m, "Module", py::module_local())
2044       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2045       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2046       .def_static(
2047           "parse",
2048           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2049             MlirModule module = mlirModuleCreateParse(
2050                 context->get(), toMlirStringRef(moduleAsm));
2051             // TODO: Rework error reporting once diagnostic engine is exposed
2052             // in C API.
2053             if (mlirModuleIsNull(module)) {
2054               throw SetPyError(
2055                   PyExc_ValueError,
2056                   "Unable to parse module assembly (see diagnostics)");
2057             }
2058             return PyModule::forModule(module).releaseObject();
2059           },
2060           py::arg("asm"), py::arg("context") = py::none(),
2061           kModuleParseDocstring)
2062       .def_static(
2063           "create",
2064           [](DefaultingPyLocation loc) {
2065             MlirModule module = mlirModuleCreateEmpty(loc);
2066             return PyModule::forModule(module).releaseObject();
2067           },
2068           py::arg("loc") = py::none(), "Creates an empty module")
2069       .def_property_readonly(
2070           "context",
2071           [](PyModule &self) { return self.getContext().getObject(); },
2072           "Context that created the Module")
2073       .def_property_readonly(
2074           "operation",
2075           [](PyModule &self) {
2076             return PyOperation::forOperation(self.getContext(),
2077                                              mlirModuleGetOperation(self.get()),
2078                                              self.getRef().releaseObject())
2079                 .releaseObject();
2080           },
2081           "Accesses the module as an operation")
2082       .def_property_readonly(
2083           "body",
2084           [](PyModule &self) {
2085             PyOperationRef module_op = PyOperation::forOperation(
2086                 self.getContext(), mlirModuleGetOperation(self.get()),
2087                 self.getRef().releaseObject());
2088             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2089             return returnBlock;
2090           },
2091           "Return the block for this module")
2092       .def(
2093           "dump",
2094           [](PyModule &self) {
2095             mlirOperationDump(mlirModuleGetOperation(self.get()));
2096           },
2097           kDumpDocstring)
2098       .def(
2099           "__str__",
2100           [](PyModule &self) {
2101             MlirOperation operation = mlirModuleGetOperation(self.get());
2102             PyPrintAccumulator printAccum;
2103             mlirOperationPrint(operation, printAccum.getCallback(),
2104                                printAccum.getUserData());
2105             return printAccum.join();
2106           },
2107           kOperationStrDunderDocstring);
2108 
2109   //----------------------------------------------------------------------------
2110   // Mapping of Operation.
2111   //----------------------------------------------------------------------------
2112   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2113       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2114                              [](PyOperationBase &self) {
2115                                return self.getOperation().getCapsule();
2116                              })
2117       .def("__eq__",
2118            [](PyOperationBase &self, PyOperationBase &other) {
2119              return &self.getOperation() == &other.getOperation();
2120            })
2121       .def("__eq__",
2122            [](PyOperationBase &self, py::object other) { return false; })
2123       .def_property_readonly("attributes",
2124                              [](PyOperationBase &self) {
2125                                return PyOpAttributeMap(
2126                                    self.getOperation().getRef());
2127                              })
2128       .def_property_readonly("operands",
2129                              [](PyOperationBase &self) {
2130                                return PyOpOperandList(
2131                                    self.getOperation().getRef());
2132                              })
2133       .def_property_readonly("regions",
2134                              [](PyOperationBase &self) {
2135                                return PyRegionList(
2136                                    self.getOperation().getRef());
2137                              })
2138       .def_property_readonly(
2139           "results",
2140           [](PyOperationBase &self) {
2141             return PyOpResultList(self.getOperation().getRef());
2142           },
2143           "Returns the list of Operation results.")
2144       .def_property_readonly(
2145           "result",
2146           [](PyOperationBase &self) {
2147             auto &operation = self.getOperation();
2148             auto numResults = mlirOperationGetNumResults(operation);
2149             if (numResults != 1) {
2150               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2151               throw SetPyError(
2152                   PyExc_ValueError,
2153                   Twine("Cannot call .result on operation ") +
2154                       StringRef(name.data, name.length) + " which has " +
2155                       Twine(numResults) +
2156                       " results (it is only valid for operations with a "
2157                       "single result)");
2158             }
2159             return PyOpResult(operation.getRef(),
2160                               mlirOperationGetResult(operation, 0));
2161           },
2162           "Shortcut to get an op result if it has only one (throws an error "
2163           "otherwise).")
2164       .def_property_readonly(
2165           "location",
2166           [](PyOperationBase &self) {
2167             PyOperation &operation = self.getOperation();
2168             return PyLocation(operation.getContext(),
2169                               mlirOperationGetLocation(operation.get()));
2170           },
2171           "Returns the source location the operation was defined or derived "
2172           "from.")
2173       .def(
2174           "__str__",
2175           [](PyOperationBase &self) {
2176             return self.getAsm(/*binary=*/false,
2177                                /*largeElementsLimit=*/llvm::None,
2178                                /*enableDebugInfo=*/false,
2179                                /*prettyDebugInfo=*/false,
2180                                /*printGenericOpForm=*/false,
2181                                /*useLocalScope=*/false);
2182           },
2183           "Returns the assembly form of the operation.")
2184       .def("print", &PyOperationBase::print,
2185            // Careful: Lots of arguments must match up with print method.
2186            py::arg("file") = py::none(), py::arg("binary") = false,
2187            py::arg("large_elements_limit") = py::none(),
2188            py::arg("enable_debug_info") = false,
2189            py::arg("pretty_debug_info") = false,
2190            py::arg("print_generic_op_form") = false,
2191            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2192       .def("get_asm", &PyOperationBase::getAsm,
2193            // Careful: Lots of arguments must match up with get_asm method.
2194            py::arg("binary") = false,
2195            py::arg("large_elements_limit") = py::none(),
2196            py::arg("enable_debug_info") = false,
2197            py::arg("pretty_debug_info") = false,
2198            py::arg("print_generic_op_form") = false,
2199            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2200       .def(
2201           "verify",
2202           [](PyOperationBase &self) {
2203             return mlirOperationVerify(self.getOperation());
2204           },
2205           "Verify the operation and return true if it passes, false if it "
2206           "fails.")
2207       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2208            "Puts self immediately after the other operation in its parent "
2209            "block.")
2210       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2211            "Puts self immediately before the other operation in its parent "
2212            "block.")
2213       .def(
2214           "detach_from_parent",
2215           [](PyOperationBase &self) {
2216             PyOperation &operation = self.getOperation();
2217             operation.checkValid();
2218             if (!operation.isAttached())
2219               throw py::value_error("Detached operation has no parent.");
2220 
2221             operation.detachFromParent();
2222             return operation.createOpView();
2223           },
2224           "Detaches the operation from its parent block.");
2225 
2226   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2227       .def_static("create", &PyOperation::create, py::arg("name"),
2228                   py::arg("results") = py::none(),
2229                   py::arg("operands") = py::none(),
2230                   py::arg("attributes") = py::none(),
2231                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2232                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2233                   kOperationCreateDocstring)
2234       .def_property_readonly("parent",
2235                              [](PyOperation &self) -> py::object {
2236                                auto parent = self.getParentOperation();
2237                                if (parent)
2238                                  return parent->getObject();
2239                                return py::none();
2240                              })
2241       .def("erase", &PyOperation::erase)
2242       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2243                              &PyOperation::getCapsule)
2244       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2245       .def_property_readonly("name",
2246                              [](PyOperation &self) {
2247                                self.checkValid();
2248                                MlirOperation operation = self.get();
2249                                MlirStringRef name = mlirIdentifierStr(
2250                                    mlirOperationGetName(operation));
2251                                return py::str(name.data, name.length);
2252                              })
2253       .def_property_readonly(
2254           "context",
2255           [](PyOperation &self) {
2256             self.checkValid();
2257             return self.getContext().getObject();
2258           },
2259           "Context that owns the Operation")
2260       .def_property_readonly("opview", &PyOperation::createOpView);
2261 
2262   auto opViewClass =
2263       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2264           .def(py::init<py::object>())
2265           .def_property_readonly("operation", &PyOpView::getOperationObject)
2266           .def_property_readonly(
2267               "context",
2268               [](PyOpView &self) {
2269                 return self.getOperation().getContext().getObject();
2270               },
2271               "Context that owns the Operation")
2272           .def("__str__", [](PyOpView &self) {
2273             return py::str(self.getOperationObject());
2274           });
2275   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2276   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2277   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2278   opViewClass.attr("build_generic") = classmethod(
2279       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2280       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2281       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2282       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2283       "Builds a specific, generated OpView based on class level attributes.");
2284 
2285   //----------------------------------------------------------------------------
2286   // Mapping of PyRegion.
2287   //----------------------------------------------------------------------------
2288   py::class_<PyRegion>(m, "Region", py::module_local())
2289       .def_property_readonly(
2290           "blocks",
2291           [](PyRegion &self) {
2292             return PyBlockList(self.getParentOperation(), self.get());
2293           },
2294           "Returns a forward-optimized sequence of blocks.")
2295       .def_property_readonly(
2296           "owner",
2297           [](PyRegion &self) {
2298             return self.getParentOperation()->createOpView();
2299           },
2300           "Returns the operation owning this region.")
2301       .def(
2302           "__iter__",
2303           [](PyRegion &self) {
2304             self.checkValid();
2305             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2306             return PyBlockIterator(self.getParentOperation(), firstBlock);
2307           },
2308           "Iterates over blocks in the region.")
2309       .def("__eq__",
2310            [](PyRegion &self, PyRegion &other) {
2311              return self.get().ptr == other.get().ptr;
2312            })
2313       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2314 
2315   //----------------------------------------------------------------------------
2316   // Mapping of PyBlock.
2317   //----------------------------------------------------------------------------
2318   py::class_<PyBlock>(m, "Block", py::module_local())
2319       .def_property_readonly(
2320           "owner",
2321           [](PyBlock &self) {
2322             return self.getParentOperation()->createOpView();
2323           },
2324           "Returns the owning operation of this block.")
2325       .def_property_readonly(
2326           "region",
2327           [](PyBlock &self) {
2328             MlirRegion region = mlirBlockGetParentRegion(self.get());
2329             return PyRegion(self.getParentOperation(), region);
2330           },
2331           "Returns the owning region of this block.")
2332       .def_property_readonly(
2333           "arguments",
2334           [](PyBlock &self) {
2335             return PyBlockArgumentList(self.getParentOperation(), self.get());
2336           },
2337           "Returns a list of block arguments.")
2338       .def_property_readonly(
2339           "operations",
2340           [](PyBlock &self) {
2341             return PyOperationList(self.getParentOperation(), self.get());
2342           },
2343           "Returns a forward-optimized sequence of operations.")
2344       .def_static(
2345           "create_at_start",
2346           [](PyRegion &parent, py::list pyArgTypes) {
2347             parent.checkValid();
2348             llvm::SmallVector<MlirType, 4> argTypes;
2349             argTypes.reserve(pyArgTypes.size());
2350             for (auto &pyArg : pyArgTypes) {
2351               argTypes.push_back(pyArg.cast<PyType &>());
2352             }
2353 
2354             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2355             mlirRegionInsertOwnedBlock(parent, 0, block);
2356             return PyBlock(parent.getParentOperation(), block);
2357           },
2358           py::arg("parent"), py::arg("pyArgTypes") = py::list(),
2359           "Creates and returns a new Block at the beginning of the given "
2360           "region (with given argument types).")
2361       .def(
2362           "create_before",
2363           [](PyBlock &self, py::args pyArgTypes) {
2364             self.checkValid();
2365             llvm::SmallVector<MlirType, 4> argTypes;
2366             argTypes.reserve(pyArgTypes.size());
2367             for (auto &pyArg : pyArgTypes) {
2368               argTypes.push_back(pyArg.cast<PyType &>());
2369             }
2370 
2371             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2372             MlirRegion region = mlirBlockGetParentRegion(self.get());
2373             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2374             return PyBlock(self.getParentOperation(), block);
2375           },
2376           "Creates and returns a new Block before this block "
2377           "(with given argument types).")
2378       .def(
2379           "create_after",
2380           [](PyBlock &self, py::args pyArgTypes) {
2381             self.checkValid();
2382             llvm::SmallVector<MlirType, 4> argTypes;
2383             argTypes.reserve(pyArgTypes.size());
2384             for (auto &pyArg : pyArgTypes) {
2385               argTypes.push_back(pyArg.cast<PyType &>());
2386             }
2387 
2388             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2389             MlirRegion region = mlirBlockGetParentRegion(self.get());
2390             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2391             return PyBlock(self.getParentOperation(), block);
2392           },
2393           "Creates and returns a new Block after this block "
2394           "(with given argument types).")
2395       .def(
2396           "__iter__",
2397           [](PyBlock &self) {
2398             self.checkValid();
2399             MlirOperation firstOperation =
2400                 mlirBlockGetFirstOperation(self.get());
2401             return PyOperationIterator(self.getParentOperation(),
2402                                        firstOperation);
2403           },
2404           "Iterates over operations in the block.")
2405       .def("__eq__",
2406            [](PyBlock &self, PyBlock &other) {
2407              return self.get().ptr == other.get().ptr;
2408            })
2409       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2410       .def(
2411           "__str__",
2412           [](PyBlock &self) {
2413             self.checkValid();
2414             PyPrintAccumulator printAccum;
2415             mlirBlockPrint(self.get(), printAccum.getCallback(),
2416                            printAccum.getUserData());
2417             return printAccum.join();
2418           },
2419           "Returns the assembly form of the block.")
2420       .def(
2421           "append",
2422           [](PyBlock &self, PyOperationBase &operation) {
2423             if (operation.getOperation().isAttached())
2424               operation.getOperation().detachFromParent();
2425 
2426             MlirOperation mlirOperation = operation.getOperation().get();
2427             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2428             operation.getOperation().setAttached(
2429                 self.getParentOperation().getObject());
2430           },
2431           "Appends an operation to this block. If the operation is currently "
2432           "in another block, it will be moved.");
2433 
2434   //----------------------------------------------------------------------------
2435   // Mapping of PyInsertionPoint.
2436   //----------------------------------------------------------------------------
2437 
2438   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2439       .def(py::init<PyBlock &>(), py::arg("block"),
2440            "Inserts after the last operation but still inside the block.")
2441       .def("__enter__", &PyInsertionPoint::contextEnter)
2442       .def("__exit__", &PyInsertionPoint::contextExit)
2443       .def_property_readonly_static(
2444           "current",
2445           [](py::object & /*class*/) {
2446             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2447             if (!ip)
2448               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2449             return ip;
2450           },
2451           "Gets the InsertionPoint bound to the current thread or raises "
2452           "ValueError if none has been set")
2453       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2454            "Inserts before a referenced operation.")
2455       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2456                   py::arg("block"), "Inserts at the beginning of the block.")
2457       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2458                   py::arg("block"), "Inserts before the block terminator.")
2459       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2460            "Inserts an operation.")
2461       .def_property_readonly(
2462           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2463           "Returns the block that this InsertionPoint points to.");
2464 
2465   //----------------------------------------------------------------------------
2466   // Mapping of PyAttribute.
2467   //----------------------------------------------------------------------------
2468   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2469       // Delegate to the PyAttribute copy constructor, which will also lifetime
2470       // extend the backing context which owns the MlirAttribute.
2471       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2472            "Casts the passed attribute to the generic Attribute")
2473       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2474                              &PyAttribute::getCapsule)
2475       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2476       .def_static(
2477           "parse",
2478           [](std::string attrSpec, DefaultingPyMlirContext context) {
2479             MlirAttribute type = mlirAttributeParseGet(
2480                 context->get(), toMlirStringRef(attrSpec));
2481             // TODO: Rework error reporting once diagnostic engine is exposed
2482             // in C API.
2483             if (mlirAttributeIsNull(type)) {
2484               throw SetPyError(PyExc_ValueError,
2485                                Twine("Unable to parse attribute: '") +
2486                                    attrSpec + "'");
2487             }
2488             return PyAttribute(context->getRef(), type);
2489           },
2490           py::arg("asm"), py::arg("context") = py::none(),
2491           "Parses an attribute from an assembly form")
2492       .def_property_readonly(
2493           "context",
2494           [](PyAttribute &self) { return self.getContext().getObject(); },
2495           "Context that owns the Attribute")
2496       .def_property_readonly("type",
2497                              [](PyAttribute &self) {
2498                                return PyType(self.getContext()->getRef(),
2499                                              mlirAttributeGetType(self));
2500                              })
2501       .def(
2502           "get_named",
2503           [](PyAttribute &self, std::string name) {
2504             return PyNamedAttribute(self, std::move(name));
2505           },
2506           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2507       .def("__eq__",
2508            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2509       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2510       .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; })
2511       .def(
2512           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2513           kDumpDocstring)
2514       .def(
2515           "__str__",
2516           [](PyAttribute &self) {
2517             PyPrintAccumulator printAccum;
2518             mlirAttributePrint(self, printAccum.getCallback(),
2519                                printAccum.getUserData());
2520             return printAccum.join();
2521           },
2522           "Returns the assembly form of the Attribute.")
2523       .def("__repr__", [](PyAttribute &self) {
2524         // Generally, assembly formats are not printed for __repr__ because
2525         // this can cause exceptionally long debug output and exceptions.
2526         // However, attribute values are generally considered useful and are
2527         // printed. This may need to be re-evaluated if debug dumps end up
2528         // being excessive.
2529         PyPrintAccumulator printAccum;
2530         printAccum.parts.append("Attribute(");
2531         mlirAttributePrint(self, printAccum.getCallback(),
2532                            printAccum.getUserData());
2533         printAccum.parts.append(")");
2534         return printAccum.join();
2535       });
2536 
2537   //----------------------------------------------------------------------------
2538   // Mapping of PyNamedAttribute
2539   //----------------------------------------------------------------------------
2540   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2541       .def("__repr__",
2542            [](PyNamedAttribute &self) {
2543              PyPrintAccumulator printAccum;
2544              printAccum.parts.append("NamedAttribute(");
2545              printAccum.parts.append(
2546                  mlirIdentifierStr(self.namedAttr.name).data);
2547              printAccum.parts.append("=");
2548              mlirAttributePrint(self.namedAttr.attribute,
2549                                 printAccum.getCallback(),
2550                                 printAccum.getUserData());
2551              printAccum.parts.append(")");
2552              return printAccum.join();
2553            })
2554       .def_property_readonly(
2555           "name",
2556           [](PyNamedAttribute &self) {
2557             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2558                            mlirIdentifierStr(self.namedAttr.name).length);
2559           },
2560           "The name of the NamedAttribute binding")
2561       .def_property_readonly(
2562           "attr",
2563           [](PyNamedAttribute &self) {
2564             // TODO: When named attribute is removed/refactored, also remove
2565             // this constructor (it does an inefficient table lookup).
2566             auto contextRef = PyMlirContext::forContext(
2567                 mlirAttributeGetContext(self.namedAttr.attribute));
2568             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2569           },
2570           py::keep_alive<0, 1>(),
2571           "The underlying generic attribute of the NamedAttribute binding");
2572 
2573   //----------------------------------------------------------------------------
2574   // Mapping of PyType.
2575   //----------------------------------------------------------------------------
2576   py::class_<PyType>(m, "Type", py::module_local())
2577       // Delegate to the PyType copy constructor, which will also lifetime
2578       // extend the backing context which owns the MlirType.
2579       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2580            "Casts the passed type to the generic Type")
2581       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2582       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2583       .def_static(
2584           "parse",
2585           [](std::string typeSpec, DefaultingPyMlirContext context) {
2586             MlirType type =
2587                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2588             // TODO: Rework error reporting once diagnostic engine is exposed
2589             // in C API.
2590             if (mlirTypeIsNull(type)) {
2591               throw SetPyError(PyExc_ValueError,
2592                                Twine("Unable to parse type: '") + typeSpec +
2593                                    "'");
2594             }
2595             return PyType(context->getRef(), type);
2596           },
2597           py::arg("asm"), py::arg("context") = py::none(),
2598           kContextParseTypeDocstring)
2599       .def_property_readonly(
2600           "context", [](PyType &self) { return self.getContext().getObject(); },
2601           "Context that owns the Type")
2602       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2603       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2604       .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; })
2605       .def(
2606           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2607       .def(
2608           "__str__",
2609           [](PyType &self) {
2610             PyPrintAccumulator printAccum;
2611             mlirTypePrint(self, printAccum.getCallback(),
2612                           printAccum.getUserData());
2613             return printAccum.join();
2614           },
2615           "Returns the assembly form of the type.")
2616       .def("__repr__", [](PyType &self) {
2617         // Generally, assembly formats are not printed for __repr__ because
2618         // this can cause exceptionally long debug output and exceptions.
2619         // However, types are an exception as they typically have compact
2620         // assembly forms and printing them is useful.
2621         PyPrintAccumulator printAccum;
2622         printAccum.parts.append("Type(");
2623         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2624         printAccum.parts.append(")");
2625         return printAccum.join();
2626       });
2627 
2628   //----------------------------------------------------------------------------
2629   // Mapping of Value.
2630   //----------------------------------------------------------------------------
2631   py::class_<PyValue>(m, "Value", py::module_local())
2632       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2633       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2634       .def_property_readonly(
2635           "context",
2636           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2637           "Context in which the value lives.")
2638       .def(
2639           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2640           kDumpDocstring)
2641       .def_property_readonly(
2642           "owner",
2643           [](PyValue &self) {
2644             assert(mlirOperationEqual(self.getParentOperation()->get(),
2645                                       mlirOpResultGetOwner(self.get())) &&
2646                    "expected the owner of the value in Python to match that in "
2647                    "the IR");
2648             return self.getParentOperation().getObject();
2649           })
2650       .def("__eq__",
2651            [](PyValue &self, PyValue &other) {
2652              return self.get().ptr == other.get().ptr;
2653            })
2654       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2655       .def(
2656           "__str__",
2657           [](PyValue &self) {
2658             PyPrintAccumulator printAccum;
2659             printAccum.parts.append("Value(");
2660             mlirValuePrint(self.get(), printAccum.getCallback(),
2661                            printAccum.getUserData());
2662             printAccum.parts.append(")");
2663             return printAccum.join();
2664           },
2665           kValueDunderStrDocstring)
2666       .def_property_readonly("type", [](PyValue &self) {
2667         return PyType(self.getParentOperation()->getContext(),
2668                       mlirValueGetType(self.get()));
2669       });
  // Value-derived classes, registered after "Value" above.
  // NOTE(review): presumably these derive from Value on the Python side, so
  // they must stay after the Value registration — confirm in IRModule.h.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings (lists/iterators over blocks, operations, regions,
  // operands, results and attributes). The registration order is kept as-is.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  PyGlobalDebugFlag::bind(m);
2687 }
2688