1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include <pybind11/stl.h>
23 
24 namespace py = pybind11;
25 using namespace mlir;
26 using namespace mlir::python;
27 
28 using llvm::SmallVector;
29 using llvm::StringRef;
30 using llvm::Twine;
31 
32 //------------------------------------------------------------------------------
33 // Docstrings (trivial, non-duplicated docstrings are included inline).
34 //------------------------------------------------------------------------------
35 
/// Python docstring for the Context-level type parser (the text documents that
/// it returns a Type or raises ValueError on parse failure).
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

/// Python docstring for the call-site Location factory.
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

/// Python docstring for the file/line/column Location factory.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

/// Python docstring for the named Location factory. Note the `DocString`
/// capitalization differs from the sibling constants; callers elsewhere refer
/// to it by this exact name.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

/// Python docstring for parsing a module from its textual assembly form.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

/// Python docstring for the generic operation-creation entry point; it lists
/// the accepted keyword arguments and states that the result is detached.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
78 
/// Python docstring for Operation.print(); the keyword names listed here must
/// match the snake_case names declared at the binding site.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file-like object.

Args:
  file: The file-like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: If set, elide elements attributes with more than this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";
97 
/// Python docstring for Operation.get_asm(); it defers to print() for the
/// shared formatting keyword arguments.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

/// Python docstring for Operation.__str__ (default printing options only).
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

/// Shared Python docstring for `dump` methods that write to stderr.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

/// Python docstring for BlockList.append (bound in PyBlockList::bind).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

/// Python docstring for Value.__str__; it describes the different output for
/// block arguments vs. operation results.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
136 
137 //------------------------------------------------------------------------------
138 // Utilities.
139 //------------------------------------------------------------------------------
140 
141 /// Helper for creating an @classmethod.
142 template <class Func, typename... Args>
143 py::object classmethod(Func f, Args... args) {
144   py::object cf = py::cpp_function(f, args...);
145   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
146 }
147 
148 static py::object
149 createCustomDialectWrapper(const std::string &dialectNamespace,
150                            py::object dialectDescriptor) {
151   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
152   if (!dialectClass) {
153     // Use the base class.
154     return py::cast(PyDialect(std::move(dialectDescriptor)));
155   }
156 
157   // Create the custom implementation.
158   return (*dialectClass)(std::move(dialectDescriptor));
159 }
160 
161 static MlirStringRef toMlirStringRef(const std::string &s) {
162   return mlirStringRefCreate(s.data(), s.size());
163 }
164 
165 /// Wrapper for the global LLVM debugging flag.
166 struct PyGlobalDebugFlag {
167   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
168 
169   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
170 
171   static void bind(py::module &m) {
172     // Debug flags.
173     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
174         .def_property_static("flag", &PyGlobalDebugFlag::get,
175                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
176   }
177 };
178 
179 //------------------------------------------------------------------------------
180 // Collections.
181 //------------------------------------------------------------------------------
182 
183 namespace {
184 
185 class PyRegionIterator {
186 public:
187   PyRegionIterator(PyOperationRef operation)
188       : operation(std::move(operation)) {}
189 
190   PyRegionIterator &dunderIter() { return *this; }
191 
192   PyRegion dunderNext() {
193     operation->checkValid();
194     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
195       throw py::stop_iteration();
196     }
197     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
198     return PyRegion(operation, region);
199   }
200 
201   static void bind(py::module &m) {
202     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
203         .def("__iter__", &PyRegionIterator::dunderIter)
204         .def("__next__", &PyRegionIterator::dunderNext);
205   }
206 
207 private:
208   PyOperationRef operation;
209   int nextIndex = 0;
210 };
211 
212 /// Regions of an op are fixed length and indexed numerically so are represented
213 /// with a sequence-like container.
214 class PyRegionList {
215 public:
216   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
217 
218   intptr_t dunderLen() {
219     operation->checkValid();
220     return mlirOperationGetNumRegions(operation->get());
221   }
222 
223   PyRegion dunderGetItem(intptr_t index) {
224     // dunderLen checks validity.
225     if (index < 0 || index >= dunderLen()) {
226       throw SetPyError(PyExc_IndexError,
227                        "attempt to access out of bounds region");
228     }
229     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
230     return PyRegion(operation, region);
231   }
232 
233   static void bind(py::module &m) {
234     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
235         .def("__len__", &PyRegionList::dunderLen)
236         .def("__getitem__", &PyRegionList::dunderGetItem);
237   }
238 
239 private:
240   PyOperationRef operation;
241 };
242 
243 class PyBlockIterator {
244 public:
245   PyBlockIterator(PyOperationRef operation, MlirBlock next)
246       : operation(std::move(operation)), next(next) {}
247 
248   PyBlockIterator &dunderIter() { return *this; }
249 
250   PyBlock dunderNext() {
251     operation->checkValid();
252     if (mlirBlockIsNull(next)) {
253       throw py::stop_iteration();
254     }
255 
256     PyBlock returnBlock(operation, next);
257     next = mlirBlockGetNextInRegion(next);
258     return returnBlock;
259   }
260 
261   static void bind(py::module &m) {
262     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
263         .def("__iter__", &PyBlockIterator::dunderIter)
264         .def("__next__", &PyBlockIterator::dunderNext);
265   }
266 
267 private:
268   PyOperationRef operation;
269   MlirBlock next;
270 };
271 
272 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
273 /// we present them as a more full-featured list-like container but optimize
274 /// it for forward iteration. Blocks are always owned by a region.
275 class PyBlockList {
276 public:
277   PyBlockList(PyOperationRef operation, MlirRegion region)
278       : operation(std::move(operation)), region(region) {}
279 
280   PyBlockIterator dunderIter() {
281     operation->checkValid();
282     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
283   }
284 
285   intptr_t dunderLen() {
286     operation->checkValid();
287     intptr_t count = 0;
288     MlirBlock block = mlirRegionGetFirstBlock(region);
289     while (!mlirBlockIsNull(block)) {
290       count += 1;
291       block = mlirBlockGetNextInRegion(block);
292     }
293     return count;
294   }
295 
296   PyBlock dunderGetItem(intptr_t index) {
297     operation->checkValid();
298     if (index < 0) {
299       throw SetPyError(PyExc_IndexError,
300                        "attempt to access out of bounds block");
301     }
302     MlirBlock block = mlirRegionGetFirstBlock(region);
303     while (!mlirBlockIsNull(block)) {
304       if (index == 0) {
305         return PyBlock(operation, block);
306       }
307       block = mlirBlockGetNextInRegion(block);
308       index -= 1;
309     }
310     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
311   }
312 
313   PyBlock appendBlock(py::args pyArgTypes) {
314     operation->checkValid();
315     llvm::SmallVector<MlirType, 4> argTypes;
316     argTypes.reserve(pyArgTypes.size());
317     for (auto &pyArg : pyArgTypes) {
318       argTypes.push_back(pyArg.cast<PyType &>());
319     }
320 
321     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
322     mlirRegionAppendOwnedBlock(region, block);
323     return PyBlock(operation, block);
324   }
325 
326   static void bind(py::module &m) {
327     py::class_<PyBlockList>(m, "BlockList", py::module_local())
328         .def("__getitem__", &PyBlockList::dunderGetItem)
329         .def("__iter__", &PyBlockList::dunderIter)
330         .def("__len__", &PyBlockList::dunderLen)
331         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
332   }
333 
334 private:
335   PyOperationRef operation;
336   MlirRegion region;
337 };
338 
339 class PyOperationIterator {
340 public:
341   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
342       : parentOperation(std::move(parentOperation)), next(next) {}
343 
344   PyOperationIterator &dunderIter() { return *this; }
345 
346   py::object dunderNext() {
347     parentOperation->checkValid();
348     if (mlirOperationIsNull(next)) {
349       throw py::stop_iteration();
350     }
351 
352     PyOperationRef returnOperation =
353         PyOperation::forOperation(parentOperation->getContext(), next);
354     next = mlirOperationGetNextInBlock(next);
355     return returnOperation->createOpView();
356   }
357 
358   static void bind(py::module &m) {
359     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
360         .def("__iter__", &PyOperationIterator::dunderIter)
361         .def("__next__", &PyOperationIterator::dunderNext);
362   }
363 
364 private:
365   PyOperationRef parentOperation;
366   MlirOperation next;
367 };
368 
369 /// Operations are exposed by the C-API as a forward-only linked list. In
370 /// Python, we present them as a more full-featured list-like container but
371 /// optimize it for forward iteration. Iterable operations are always owned
372 /// by a block.
373 class PyOperationList {
374 public:
375   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
376       : parentOperation(std::move(parentOperation)), block(block) {}
377 
378   PyOperationIterator dunderIter() {
379     parentOperation->checkValid();
380     return PyOperationIterator(parentOperation,
381                                mlirBlockGetFirstOperation(block));
382   }
383 
384   intptr_t dunderLen() {
385     parentOperation->checkValid();
386     intptr_t count = 0;
387     MlirOperation childOp = mlirBlockGetFirstOperation(block);
388     while (!mlirOperationIsNull(childOp)) {
389       count += 1;
390       childOp = mlirOperationGetNextInBlock(childOp);
391     }
392     return count;
393   }
394 
395   py::object dunderGetItem(intptr_t index) {
396     parentOperation->checkValid();
397     if (index < 0) {
398       throw SetPyError(PyExc_IndexError,
399                        "attempt to access out of bounds operation");
400     }
401     MlirOperation childOp = mlirBlockGetFirstOperation(block);
402     while (!mlirOperationIsNull(childOp)) {
403       if (index == 0) {
404         return PyOperation::forOperation(parentOperation->getContext(), childOp)
405             ->createOpView();
406       }
407       childOp = mlirOperationGetNextInBlock(childOp);
408       index -= 1;
409     }
410     throw SetPyError(PyExc_IndexError,
411                      "attempt to access out of bounds operation");
412   }
413 
414   static void bind(py::module &m) {
415     py::class_<PyOperationList>(m, "OperationList", py::module_local())
416         .def("__getitem__", &PyOperationList::dunderGetItem)
417         .def("__iter__", &PyOperationList::dunderIter)
418         .def("__len__", &PyOperationList::dunderLen);
419   }
420 
421 private:
422   PyOperationRef parentOperation;
423   MlirBlock block;
424 };
425 
426 } // namespace
427 
428 //------------------------------------------------------------------------------
429 // PyMlirContext
430 //------------------------------------------------------------------------------
431 
432 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
433   py::gil_scoped_acquire acquire;
434   auto &liveContexts = getLiveContexts();
435   liveContexts[context.ptr] = this;
436 }
437 
438 PyMlirContext::~PyMlirContext() {
439   // Note that the only public way to construct an instance is via the
440   // forContext method, which always puts the associated handle into
441   // liveContexts.
442   py::gil_scoped_acquire acquire;
443   getLiveContexts().erase(context.ptr);
444   mlirContextDestroy(context);
445 }
446 
447 py::object PyMlirContext::getCapsule() {
448   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
449 }
450 
451 py::object PyMlirContext::createFromCapsule(py::object capsule) {
452   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
453   if (mlirContextIsNull(rawContext))
454     throw py::error_already_set();
455   return forContext(rawContext).releaseObject();
456 }
457 
458 PyMlirContext *PyMlirContext::createNewContextForInit() {
459   MlirContext context = mlirContextCreate();
460   mlirRegisterAllDialects(context);
461   return new PyMlirContext(context);
462 }
463 
464 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
465   py::gil_scoped_acquire acquire;
466   auto &liveContexts = getLiveContexts();
467   auto it = liveContexts.find(context.ptr);
468   if (it == liveContexts.end()) {
469     // Create.
470     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
471     py::object pyRef = py::cast(unownedContextWrapper);
472     assert(pyRef && "cast to py::object failed");
473     liveContexts[context.ptr] = unownedContextWrapper;
474     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
475   }
476   // Use existing.
477   py::object pyRef = py::cast(it->second);
478   return PyMlirContextRef(it->second, std::move(pyRef));
479 }
480 
481 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
482   static LiveContextMap liveContexts;
483   return liveContexts;
484 }
485 
486 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
487 
488 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
489 
490 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
491 
492 pybind11::object PyMlirContext::contextEnter() {
493   return PyThreadContextEntry::pushContext(*this);
494 }
495 
496 void PyMlirContext::contextExit(pybind11::object excType,
497                                 pybind11::object excVal,
498                                 pybind11::object excTb) {
499   PyThreadContextEntry::popContext(*this);
500 }
501 
502 PyMlirContext &DefaultingPyMlirContext::resolve() {
503   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
504   if (!context) {
505     throw SetPyError(
506         PyExc_RuntimeError,
507         "An MLIR function requires a Context but none was provided in the call "
508         "or from the surrounding environment. Either pass to the function with "
509         "a 'context=' argument or establish a default using 'with Context():'");
510   }
511   return *context;
512 }
513 
514 //------------------------------------------------------------------------------
515 // PyThreadContextEntry management
516 //------------------------------------------------------------------------------
517 
518 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
519   static thread_local std::vector<PyThreadContextEntry> stack;
520   return stack;
521 }
522 
523 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
524   auto &stack = getStack();
525   if (stack.empty())
526     return nullptr;
527   return &stack.back();
528 }
529 
530 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
531                                 py::object insertionPoint,
532                                 py::object location) {
533   auto &stack = getStack();
534   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
535                      std::move(location));
536   // If the new stack has more than one entry and the context of the new top
537   // entry matches the previous, copy the insertionPoint and location from the
538   // previous entry if missing from the new top entry.
539   if (stack.size() > 1) {
540     auto &prev = *(stack.rbegin() + 1);
541     auto &current = stack.back();
542     if (current.context.is(prev.context)) {
543       // Default non-context objects from the previous entry.
544       if (!current.insertionPoint)
545         current.insertionPoint = prev.insertionPoint;
546       if (!current.location)
547         current.location = prev.location;
548     }
549   }
550 }
551 
552 PyMlirContext *PyThreadContextEntry::getContext() {
553   if (!context)
554     return nullptr;
555   return py::cast<PyMlirContext *>(context);
556 }
557 
558 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
559   if (!insertionPoint)
560     return nullptr;
561   return py::cast<PyInsertionPoint *>(insertionPoint);
562 }
563 
564 PyLocation *PyThreadContextEntry::getLocation() {
565   if (!location)
566     return nullptr;
567   return py::cast<PyLocation *>(location);
568 }
569 
570 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
571   auto *tos = getTopOfStack();
572   return tos ? tos->getContext() : nullptr;
573 }
574 
575 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
576   auto *tos = getTopOfStack();
577   return tos ? tos->getInsertionPoint() : nullptr;
578 }
579 
580 PyLocation *PyThreadContextEntry::getDefaultLocation() {
581   auto *tos = getTopOfStack();
582   return tos ? tos->getLocation() : nullptr;
583 }
584 
585 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
586   py::object contextObj = py::cast(context);
587   push(FrameKind::Context, /*context=*/contextObj,
588        /*insertionPoint=*/py::object(),
589        /*location=*/py::object());
590   return contextObj;
591 }
592 
593 void PyThreadContextEntry::popContext(PyMlirContext &context) {
594   auto &stack = getStack();
595   if (stack.empty())
596     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
597   auto &tos = stack.back();
598   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
599     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
600   stack.pop_back();
601 }
602 
603 py::object
604 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
605   py::object contextObj =
606       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
607   py::object insertionPointObj = py::cast(insertionPoint);
608   push(FrameKind::InsertionPoint,
609        /*context=*/contextObj,
610        /*insertionPoint=*/insertionPointObj,
611        /*location=*/py::object());
612   return insertionPointObj;
613 }
614 
615 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
616   auto &stack = getStack();
617   if (stack.empty())
618     throw SetPyError(PyExc_RuntimeError,
619                      "Unbalanced InsertionPoint enter/exit");
620   auto &tos = stack.back();
621   if (tos.frameKind != FrameKind::InsertionPoint &&
622       tos.getInsertionPoint() != &insertionPoint)
623     throw SetPyError(PyExc_RuntimeError,
624                      "Unbalanced InsertionPoint enter/exit");
625   stack.pop_back();
626 }
627 
628 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
629   py::object contextObj = location.getContext().getObject();
630   py::object locationObj = py::cast(location);
631   push(FrameKind::Location, /*context=*/contextObj,
632        /*insertionPoint=*/py::object(),
633        /*location=*/locationObj);
634   return locationObj;
635 }
636 
637 void PyThreadContextEntry::popLocation(PyLocation &location) {
638   auto &stack = getStack();
639   if (stack.empty())
640     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
641   auto &tos = stack.back();
642   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
643     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
644   stack.pop_back();
645 }
646 
647 //------------------------------------------------------------------------------
648 // PyDialect, PyDialectDescriptor, PyDialects
649 //------------------------------------------------------------------------------
650 
651 MlirDialect PyDialects::getDialectForKey(const std::string &key,
652                                          bool attrError) {
653   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
654                                                     {key.data(), key.size()});
655   if (mlirDialectIsNull(dialect)) {
656     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
657                      Twine("Dialect '") + key + "' not found");
658   }
659   return dialect;
660 }
661 
662 //------------------------------------------------------------------------------
663 // PyLocation
664 //------------------------------------------------------------------------------
665 
666 py::object PyLocation::getCapsule() {
667   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
668 }
669 
670 PyLocation PyLocation::createFromCapsule(py::object capsule) {
671   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
672   if (mlirLocationIsNull(rawLoc))
673     throw py::error_already_set();
674   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
675                     rawLoc);
676 }
677 
678 py::object PyLocation::contextEnter() {
679   return PyThreadContextEntry::pushLocation(*this);
680 }
681 
682 void PyLocation::contextExit(py::object excType, py::object excVal,
683                              py::object excTb) {
684   PyThreadContextEntry::popLocation(*this);
685 }
686 
687 PyLocation &DefaultingPyLocation::resolve() {
688   auto *location = PyThreadContextEntry::getDefaultLocation();
689   if (!location) {
690     throw SetPyError(
691         PyExc_RuntimeError,
692         "An MLIR function requires a Location but none was provided in the "
693         "call or from the surrounding environment. Either pass to the function "
694         "with a 'loc=' argument or establish a default using 'with loc:'");
695   }
696   return *location;
697 }
698 
699 //------------------------------------------------------------------------------
700 // PyModule
701 //------------------------------------------------------------------------------
702 
703 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
704     : BaseContextObject(std::move(contextRef)), module(module) {}
705 
706 PyModule::~PyModule() {
707   py::gil_scoped_acquire acquire;
708   auto &liveModules = getContext()->liveModules;
709   assert(liveModules.count(module.ptr) == 1 &&
710          "destroying module not in live map");
711   liveModules.erase(module.ptr);
712   mlirModuleDestroy(module);
713 }
714 
715 PyModuleRef PyModule::forModule(MlirModule module) {
716   MlirContext context = mlirModuleGetContext(module);
717   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
718 
719   py::gil_scoped_acquire acquire;
720   auto &liveModules = contextRef->liveModules;
721   auto it = liveModules.find(module.ptr);
722   if (it == liveModules.end()) {
723     // Create.
724     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
725     // Note that the default return value policy on cast is automatic_reference,
726     // which does not take ownership (delete will not be called).
727     // Just be explicit.
728     py::object pyRef =
729         py::cast(unownedModule, py::return_value_policy::take_ownership);
730     unownedModule->handle = pyRef;
731     liveModules[module.ptr] =
732         std::make_pair(unownedModule->handle, unownedModule);
733     return PyModuleRef(unownedModule, std::move(pyRef));
734   }
735   // Use existing.
736   PyModule *existing = it->second.second;
737   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
738   return PyModuleRef(existing, std::move(pyRef));
739 }
740 
741 py::object PyModule::createFromCapsule(py::object capsule) {
742   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
743   if (mlirModuleIsNull(rawModule))
744     throw py::error_already_set();
745   return forModule(rawModule).releaseObject();
746 }
747 
748 py::object PyModule::getCapsule() {
749   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
750 }
751 
752 //------------------------------------------------------------------------------
753 // PyOperation
754 //------------------------------------------------------------------------------
755 
756 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
757     : BaseContextObject(std::move(contextRef)), operation(operation) {}
758 
759 PyOperation::~PyOperation() {
760   // If the operation has already been invalidated there is nothing to do.
761   if (!valid)
762     return;
763   auto &liveOperations = getContext()->liveOperations;
764   assert(liveOperations.count(operation.ptr) == 1 &&
765          "destroying operation not in live map");
766   liveOperations.erase(operation.ptr);
767   if (!isAttached()) {
768     mlirOperationDestroy(operation);
769   }
770 }
771 
772 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
773                                            MlirOperation operation,
774                                            py::object parentKeepAlive) {
775   auto &liveOperations = contextRef->liveOperations;
776   // Create.
777   PyOperation *unownedOperation =
778       new PyOperation(std::move(contextRef), operation);
779   // Note that the default return value policy on cast is automatic_reference,
780   // which does not take ownership (delete will not be called).
781   // Just be explicit.
782   py::object pyRef =
783       py::cast(unownedOperation, py::return_value_policy::take_ownership);
784   unownedOperation->handle = pyRef;
785   if (parentKeepAlive) {
786     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
787   }
788   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
789   return PyOperationRef(unownedOperation, std::move(pyRef));
790 }
791 
792 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
793                                          MlirOperation operation,
794                                          py::object parentKeepAlive) {
795   auto &liveOperations = contextRef->liveOperations;
796   auto it = liveOperations.find(operation.ptr);
797   if (it == liveOperations.end()) {
798     // Create.
799     return createInstance(std::move(contextRef), operation,
800                           std::move(parentKeepAlive));
801   }
802   // Use existing.
803   PyOperation *existing = it->second.second;
804   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
805   return PyOperationRef(existing, std::move(pyRef));
806 }
807 
808 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
809                                            MlirOperation operation,
810                                            py::object parentKeepAlive) {
811   auto &liveOperations = contextRef->liveOperations;
812   assert(liveOperations.count(operation.ptr) == 0 &&
813          "cannot create detached operation that already exists");
814   (void)liveOperations;
815 
816   PyOperationRef created = createInstance(std::move(contextRef), operation,
817                                           std::move(parentKeepAlive));
818   created->attached = false;
819   return created;
820 }
821 
822 void PyOperation::checkValid() const {
823   if (!valid) {
824     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
825   }
826 }
827 
828 void PyOperationBase::print(py::object fileObject, bool binary,
829                             llvm::Optional<int64_t> largeElementsLimit,
830                             bool enableDebugInfo, bool prettyDebugInfo,
831                             bool printGenericOpForm, bool useLocalScope) {
832   PyOperation &operation = getOperation();
833   operation.checkValid();
834   if (fileObject.is_none())
835     fileObject = py::module::import("sys").attr("stdout");
836 
837   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
838     fileObject.attr("write")("// Verification failed, printing generic form\n");
839     printGenericOpForm = true;
840   }
841 
842   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
843   if (largeElementsLimit)
844     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
845   if (enableDebugInfo)
846     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
847   if (printGenericOpForm)
848     mlirOpPrintingFlagsPrintGenericOpForm(flags);
849 
850   PyFileAccumulator accum(fileObject, binary);
851   py::gil_scoped_release();
852   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
853                               accum.getUserData());
854   mlirOpPrintingFlagsDestroy(flags);
855 }
856 
857 py::object PyOperationBase::getAsm(bool binary,
858                                    llvm::Optional<int64_t> largeElementsLimit,
859                                    bool enableDebugInfo, bool prettyDebugInfo,
860                                    bool printGenericOpForm,
861                                    bool useLocalScope) {
862   py::object fileObject;
863   if (binary) {
864     fileObject = py::module::import("io").attr("BytesIO")();
865   } else {
866     fileObject = py::module::import("io").attr("StringIO")();
867   }
868   print(fileObject, /*binary=*/binary,
869         /*largeElementsLimit=*/largeElementsLimit,
870         /*enableDebugInfo=*/enableDebugInfo,
871         /*prettyDebugInfo=*/prettyDebugInfo,
872         /*printGenericOpForm=*/printGenericOpForm,
873         /*useLocalScope=*/useLocalScope);
874 
875   return fileObject.attr("getvalue")();
876 }
877 
878 void PyOperationBase::moveAfter(PyOperationBase &other) {
879   PyOperation &operation = getOperation();
880   PyOperation &otherOp = other.getOperation();
881   operation.checkValid();
882   otherOp.checkValid();
883   mlirOperationMoveAfter(operation, otherOp);
884   operation.parentKeepAlive = otherOp.parentKeepAlive;
885 }
886 
887 void PyOperationBase::moveBefore(PyOperationBase &other) {
888   PyOperation &operation = getOperation();
889   PyOperation &otherOp = other.getOperation();
890   operation.checkValid();
891   otherOp.checkValid();
892   mlirOperationMoveBefore(operation, otherOp);
893   operation.parentKeepAlive = otherOp.parentKeepAlive;
894 }
895 
896 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
897   checkValid();
898   if (!isAttached())
899     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
900   MlirOperation operation = mlirOperationGetParentOperation(get());
901   if (mlirOperationIsNull(operation))
902     return {};
903   return PyOperation::forOperation(getContext(), operation);
904 }
905 
906 PyBlock PyOperation::getBlock() {
907   checkValid();
908   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
909   MlirBlock block = mlirOperationGetBlock(get());
910   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
911   assert(parentOperation && "Operation has no parent");
912   return PyBlock{std::move(*parentOperation), block};
913 }
914 
915 py::object PyOperation::getCapsule() {
916   checkValid();
917   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
918 }
919 
920 py::object PyOperation::createFromCapsule(py::object capsule) {
921   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
922   if (mlirOperationIsNull(rawOperation))
923     throw py::error_already_set();
924   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
925   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
926       .releaseObject();
927 }
928 
/// Creates a new operation named `name` and returns it as an OpView object
/// (see createOpView()).
///
/// - `results` / `operands` / `successors`: optional lists; a None element
///   raises ValueError.
/// - `attributes`: optional dict; a non-string key or a value that cannot be
///   cast to an Attribute raises py::cast_error with a descriptive message.
/// - `regions`: number of empty regions to create; must be >= 0 (ValueError).
/// - `location`: location of the new operation; its context is the context the
///   new operation wrapper is registered in.
/// - `maybeIp`: Python `False` disables insertion, None uses the default
///   insertion point (if any), anything else is cast to an InsertionPoint.
///
/// The operation is first created detached and then, if an insertion point
/// was selected, inserted there.
py::object PyOperation::create(
    std::string name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, py::object maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this point, no
  // exceptions may be thrown until mlirOperationCreate() below consumes the
  // state; otherwise the state (including the owned regions created below)
  // will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    // Each name identifier is created in the context of its attribute.
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Fresh empty regions; ownership is handed to the operation state.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);

  // InsertPoint active? Passing Python `False` (compared by identity) opts out
  // of insertion entirely; None means "use the default insertion point".
  if (!maybeIp.is(py::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
      ip = PyThreadContextEntry::getDefaultInsertionPoint();
    } else {
      ip = py::cast<PyInsertionPoint *>(maybeIp);
    }
    // No default insertion point: the operation simply stays detached.
    if (ip)
      ip->insert(*created.get());
  }

  return created->createOpView();
}
1061 
1062 py::object PyOperation::createOpView() {
1063   checkValid();
1064   MlirIdentifier ident = mlirOperationGetName(get());
1065   MlirStringRef identStr = mlirIdentifierStr(ident);
1066   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1067       StringRef(identStr.data, identStr.length));
1068   if (opViewClass)
1069     return (*opViewClass)(getRef().getObject());
1070   return py::cast(PyOpView(getRef().getObject()));
1071 }
1072 
1073 void PyOperation::erase() {
1074   checkValid();
1075   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1076   // Python reference to a child operation is live. All children should also
1077   // have their `valid` bit set to false.
1078   auto &liveOperations = getContext()->liveOperations;
1079   if (liveOperations.count(operation.ptr))
1080     liveOperations.erase(operation.ptr);
1081   mlirOperationDestroy(operation);
1082   valid = false;
1083 }
1084 
1085 //------------------------------------------------------------------------------
1086 // PyOpView
1087 //------------------------------------------------------------------------------
1088 
1089 py::object
1090 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1091                        py::list operandList,
1092                        llvm::Optional<py::dict> attributes,
1093                        llvm::Optional<std::vector<PyBlock *>> successors,
1094                        llvm::Optional<int> regions,
1095                        DefaultingPyLocation location, py::object maybeIp) {
1096   PyMlirContextRef context = location->getContext();
1097   // Class level operation construction metadata.
1098   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1099   // Operand and result segment specs are either none, which does no
1100   // variadic unpacking, or a list of ints with segment sizes, where each
1101   // element is either a positive number (typically 1 for a scalar) or -1 to
1102   // indicate that it is derived from the length of the same-indexed operand
1103   // or result (implying that it is a list at that position).
1104   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1105   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1106 
1107   std::vector<uint32_t> operandSegmentLengths;
1108   std::vector<uint32_t> resultSegmentLengths;
1109 
1110   // Validate/determine region count.
1111   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1112   int opMinRegionCount = std::get<0>(opRegionSpec);
1113   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1114   if (!regions) {
1115     regions = opMinRegionCount;
1116   }
1117   if (*regions < opMinRegionCount) {
1118     throw py::value_error(
1119         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1120          llvm::Twine(opMinRegionCount) +
1121          " regions but was built with regions=" + llvm::Twine(*regions))
1122             .str());
1123   }
1124   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1125     throw py::value_error(
1126         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1127          llvm::Twine(opMinRegionCount) +
1128          " regions but was built with regions=" + llvm::Twine(*regions))
1129             .str());
1130   }
1131 
1132   // Unpack results.
1133   std::vector<PyType *> resultTypes;
1134   resultTypes.reserve(resultTypeList.size());
1135   if (resultSegmentSpecObj.is_none()) {
1136     // Non-variadic result unpacking.
1137     for (auto it : llvm::enumerate(resultTypeList)) {
1138       try {
1139         resultTypes.push_back(py::cast<PyType *>(it.value()));
1140         if (!resultTypes.back())
1141           throw py::cast_error();
1142       } catch (py::cast_error &err) {
1143         throw py::value_error((llvm::Twine("Result ") +
1144                                llvm::Twine(it.index()) + " of operation \"" +
1145                                name + "\" must be a Type (" + err.what() + ")")
1146                                   .str());
1147       }
1148     }
1149   } else {
1150     // Sized result unpacking.
1151     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1152     if (resultSegmentSpec.size() != resultTypeList.size()) {
1153       throw py::value_error((llvm::Twine("Operation \"") + name +
1154                              "\" requires " +
1155                              llvm::Twine(resultSegmentSpec.size()) +
1156                              " result segments but was provided " +
1157                              llvm::Twine(resultTypeList.size()))
1158                                 .str());
1159     }
1160     resultSegmentLengths.reserve(resultTypeList.size());
1161     for (auto it :
1162          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1163       int segmentSpec = std::get<1>(it.value());
1164       if (segmentSpec == 1 || segmentSpec == 0) {
1165         // Unpack unary element.
1166         try {
1167           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1168           if (resultType) {
1169             resultTypes.push_back(resultType);
1170             resultSegmentLengths.push_back(1);
1171           } else if (segmentSpec == 0) {
1172             // Allowed to be optional.
1173             resultSegmentLengths.push_back(0);
1174           } else {
1175             throw py::cast_error("was None and result is not optional");
1176           }
1177         } catch (py::cast_error &err) {
1178           throw py::value_error((llvm::Twine("Result ") +
1179                                  llvm::Twine(it.index()) + " of operation \"" +
1180                                  name + "\" must be a Type (" + err.what() +
1181                                  ")")
1182                                     .str());
1183         }
1184       } else if (segmentSpec == -1) {
1185         // Unpack sequence by appending.
1186         try {
1187           if (std::get<0>(it.value()).is_none()) {
1188             // Treat it as an empty list.
1189             resultSegmentLengths.push_back(0);
1190           } else {
1191             // Unpack the list.
1192             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1193             for (py::object segmentItem : segment) {
1194               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1195               if (!resultTypes.back()) {
1196                 throw py::cast_error("contained a None item");
1197               }
1198             }
1199             resultSegmentLengths.push_back(segment.size());
1200           }
1201         } catch (std::exception &err) {
1202           // NOTE: Sloppy to be using a catch-all here, but there are at least
1203           // three different unrelated exceptions that can be thrown in the
1204           // above "casts". Just keep the scope above small and catch them all.
1205           throw py::value_error((llvm::Twine("Result ") +
1206                                  llvm::Twine(it.index()) + " of operation \"" +
1207                                  name + "\" must be a Sequence of Types (" +
1208                                  err.what() + ")")
1209                                     .str());
1210         }
1211       } else {
1212         throw py::value_error("Unexpected segment spec");
1213       }
1214     }
1215   }
1216 
1217   // Unpack operands.
1218   std::vector<PyValue *> operands;
1219   operands.reserve(operands.size());
1220   if (operandSegmentSpecObj.is_none()) {
1221     // Non-sized operand unpacking.
1222     for (auto it : llvm::enumerate(operandList)) {
1223       try {
1224         operands.push_back(py::cast<PyValue *>(it.value()));
1225         if (!operands.back())
1226           throw py::cast_error();
1227       } catch (py::cast_error &err) {
1228         throw py::value_error((llvm::Twine("Operand ") +
1229                                llvm::Twine(it.index()) + " of operation \"" +
1230                                name + "\" must be a Value (" + err.what() + ")")
1231                                   .str());
1232       }
1233     }
1234   } else {
1235     // Sized operand unpacking.
1236     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1237     if (operandSegmentSpec.size() != operandList.size()) {
1238       throw py::value_error((llvm::Twine("Operation \"") + name +
1239                              "\" requires " +
1240                              llvm::Twine(operandSegmentSpec.size()) +
1241                              "operand segments but was provided " +
1242                              llvm::Twine(operandList.size()))
1243                                 .str());
1244     }
1245     operandSegmentLengths.reserve(operandList.size());
1246     for (auto it :
1247          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1248       int segmentSpec = std::get<1>(it.value());
1249       if (segmentSpec == 1 || segmentSpec == 0) {
1250         // Unpack unary element.
1251         try {
1252           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1253           if (operandValue) {
1254             operands.push_back(operandValue);
1255             operandSegmentLengths.push_back(1);
1256           } else if (segmentSpec == 0) {
1257             // Allowed to be optional.
1258             operandSegmentLengths.push_back(0);
1259           } else {
1260             throw py::cast_error("was None and operand is not optional");
1261           }
1262         } catch (py::cast_error &err) {
1263           throw py::value_error((llvm::Twine("Operand ") +
1264                                  llvm::Twine(it.index()) + " of operation \"" +
1265                                  name + "\" must be a Value (" + err.what() +
1266                                  ")")
1267                                     .str());
1268         }
1269       } else if (segmentSpec == -1) {
1270         // Unpack sequence by appending.
1271         try {
1272           if (std::get<0>(it.value()).is_none()) {
1273             // Treat it as an empty list.
1274             operandSegmentLengths.push_back(0);
1275           } else {
1276             // Unpack the list.
1277             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1278             for (py::object segmentItem : segment) {
1279               operands.push_back(py::cast<PyValue *>(segmentItem));
1280               if (!operands.back()) {
1281                 throw py::cast_error("contained a None item");
1282               }
1283             }
1284             operandSegmentLengths.push_back(segment.size());
1285           }
1286         } catch (std::exception &err) {
1287           // NOTE: Sloppy to be using a catch-all here, but there are at least
1288           // three different unrelated exceptions that can be thrown in the
1289           // above "casts". Just keep the scope above small and catch them all.
1290           throw py::value_error((llvm::Twine("Operand ") +
1291                                  llvm::Twine(it.index()) + " of operation \"" +
1292                                  name + "\" must be a Sequence of Values (" +
1293                                  err.what() + ")")
1294                                     .str());
1295         }
1296       } else {
1297         throw py::value_error("Unexpected segment spec");
1298       }
1299     }
1300   }
1301 
1302   // Merge operand/result segment lengths into attributes if needed.
1303   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1304     // Dup.
1305     if (attributes) {
1306       attributes = py::dict(*attributes);
1307     } else {
1308       attributes = py::dict();
1309     }
1310     if (attributes->contains("result_segment_sizes") ||
1311         attributes->contains("operand_segment_sizes")) {
1312       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1313                             "'operand_segment_sizes' attribute is unsupported. "
1314                             "Use Operation.create for such low-level access.");
1315     }
1316 
1317     // Add result_segment_sizes attribute.
1318     if (!resultSegmentLengths.empty()) {
1319       int64_t size = resultSegmentLengths.size();
1320       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1321           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1322           resultSegmentLengths.size(), resultSegmentLengths.data());
1323       (*attributes)["result_segment_sizes"] =
1324           PyAttribute(context, segmentLengthAttr);
1325     }
1326 
1327     // Add operand_segment_sizes attribute.
1328     if (!operandSegmentLengths.empty()) {
1329       int64_t size = operandSegmentLengths.size();
1330       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1331           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1332           operandSegmentLengths.size(), operandSegmentLengths.data());
1333       (*attributes)["operand_segment_sizes"] =
1334           PyAttribute(context, segmentLengthAttr);
1335     }
1336   }
1337 
1338   // Delegate to create.
1339   return PyOperation::create(std::move(name),
1340                              /*results=*/std::move(resultTypes),
1341                              /*operands=*/std::move(operands),
1342                              /*attributes=*/std::move(attributes),
1343                              /*successors=*/std::move(successors),
1344                              /*regions=*/*regions, location, maybeIp);
1345 }
1346 
/// Wraps an existing operation in an OpView.
///
/// `operationObject` may be any Python object castable to PyOperationBase. The
/// stored `operationObject` member is the canonical Python object of the
/// underlying PyOperation, not necessarily the argument passed in.
/// NOTE(review): the second initializer reads `operation`, so this relies on
/// `operation` being declared before `operationObject` in the class. Confirm
/// in IRModule.h.
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1352 
1353 py::object PyOpView::createRawSubclass(py::object userClass) {
1354   // This is... a little gross. The typical pattern is to have a pure python
1355   // class that extends OpView like:
1356   //   class AddFOp(_cext.ir.OpView):
1357   //     def __init__(self, loc, lhs, rhs):
1358   //       operation = loc.context.create_operation(
1359   //           "addf", lhs, rhs, results=[lhs.type])
1360   //       super().__init__(operation)
1361   //
1362   // I.e. The goal of the user facing type is to provide a nice constructor
1363   // that has complete freedom for the op under construction. This is at odds
1364   // with our other desire to sometimes create this object by just passing an
1365   // operation (to initialize the base class). We could do *arg and **kwargs
1366   // munging to try to make it work, but instead, we synthesize a new class
1367   // on the fly which extends this user class (AddFOp in this example) and
1368   // *give it* the base class's __init__ method, thus bypassing the
1369   // intermediate subclass's __init__ method entirely. While slightly,
1370   // underhanded, this is safe/legal because the type hierarchy has not changed
1371   // (we just added a new leaf) and we aren't mucking around with __new__.
1372   // Typically, this new class will be stored on the original as "_Raw" and will
1373   // be used for casts and other things that need a variant of the class that
1374   // is initialized purely from an operation.
1375   py::object parentMetaclass =
1376       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1377   py::dict attributes;
1378   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1379   // now.
1380   //   auto opViewType = py::type::of<PyOpView>();
1381   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1382   attributes["__init__"] = opViewType.attr("__init__");
1383   py::str origName = userClass.attr("__name__");
1384   py::str newName = py::str("_") + origName;
1385   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1386 }
1387 
1388 //------------------------------------------------------------------------------
1389 // PyInsertionPoint.
1390 //------------------------------------------------------------------------------
1391 
/// Insertion point at the end of `block` (no reference operation).
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1393 
/// Insertion point immediately before `beforeOperationBase`, inside the block
/// that contains it. The block is derived via PyOperation::getBlock().
/// NOTE(review): the `block` initializer reads `refOperation`, so this relies
/// on `refOperation` being declared before `block` in the class. Confirm in
/// IRModule.h.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1397 
1398 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1399   PyOperation &operation = operationBase.getOperation();
1400   if (operation.isAttached())
1401     throw SetPyError(PyExc_ValueError,
1402                      "Attempt to insert operation that is already attached");
1403   block.getParentOperation()->checkValid();
1404   MlirOperation beforeOp = {nullptr};
1405   if (refOperation) {
1406     // Insert before operation.
1407     (*refOperation)->checkValid();
1408     beforeOp = (*refOperation)->get();
1409   } else {
1410     // Insert at end (before null) is only valid if the block does not
1411     // already end in a known terminator (violating this will cause assertion
1412     // failures later).
1413     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1414       throw py::index_error("Cannot insert operation at the end of a block "
1415                             "that already has a terminator. Did you mean to "
1416                             "use 'InsertionPoint.at_block_terminator(block)' "
1417                             "versus 'InsertionPoint(block)'?");
1418     }
1419   }
1420   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1421   operation.setAttached();
1422 }
1423 
1424 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1425   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1426   if (mlirOperationIsNull(firstOp)) {
1427     // Just insert at end.
1428     return PyInsertionPoint(block);
1429   }
1430 
1431   // Insert before first op.
1432   PyOperationRef firstOpRef = PyOperation::forOperation(
1433       block.getParentOperation()->getContext(), firstOp);
1434   return PyInsertionPoint{block, std::move(firstOpRef)};
1435 }
1436 
1437 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1438   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1439   if (mlirOperationIsNull(terminator))
1440     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1441   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1442       block.getParentOperation()->getContext(), terminator);
1443   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1444 }
1445 
1446 py::object PyInsertionPoint::contextEnter() {
1447   return PyThreadContextEntry::pushInsertionPoint(*this);
1448 }
1449 
1450 void PyInsertionPoint::contextExit(pybind11::object excType,
1451                                    pybind11::object excVal,
1452                                    pybind11::object excTb) {
1453   PyThreadContextEntry::popInsertionPoint(*this);
1454 }
1455 
1456 //------------------------------------------------------------------------------
1457 // PyAttribute.
1458 //------------------------------------------------------------------------------
1459 
1460 bool PyAttribute::operator==(const PyAttribute &other) {
1461   return mlirAttributeEqual(attr, other.attr);
1462 }
1463 
1464 py::object PyAttribute::getCapsule() {
1465   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1466 }
1467 
1468 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1469   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1470   if (mlirAttributeIsNull(rawAttr))
1471     throw py::error_already_set();
1472   return PyAttribute(
1473       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1474 }
1475 
1476 //------------------------------------------------------------------------------
1477 // PyNamedAttribute.
1478 //------------------------------------------------------------------------------
1479 
/// Builds a named attribute from `attr` and `ownedName`.
///
/// The name is moved into a separately heap-allocated string held by this
/// object (`this->ownedName`), and the identifier is built from that copy,
/// presumably so the name bytes keep a stable address even if this object is
/// moved. NOTE(review): confirm against the member declaration in IRModule.h.
/// The constructor parameter shares the member's name, hence `this->` below.
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
    : ownedName(new std::string(std::move(ownedName))) {
  // The identifier is created in the attribute's context.
  namedAttr = mlirNamedAttributeGet(
      mlirIdentifierGet(mlirAttributeGetContext(attr),
                        toMlirStringRef(*this->ownedName)),
      attr);
}
1487 
1488 //------------------------------------------------------------------------------
1489 // PyType.
1490 //------------------------------------------------------------------------------
1491 
1492 bool PyType::operator==(const PyType &other) {
1493   return mlirTypeEqual(type, other.type);
1494 }
1495 
1496 py::object PyType::getCapsule() {
1497   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1498 }
1499 
1500 PyType PyType::createFromCapsule(py::object capsule) {
1501   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1502   if (mlirTypeIsNull(rawType))
1503     throw py::error_already_set();
1504   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1505                 rawType);
1506 }
1507 
1508 //------------------------------------------------------------------------------
1509 // PyValue and subclases.
1510 //------------------------------------------------------------------------------
1511 
1512 pybind11::object PyValue::getCapsule() {
1513   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1514 }
1515 
1516 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1517   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1518   if (mlirValueIsNull(value))
1519     throw py::error_already_set();
1520   MlirOperation owner;
1521   if (mlirValueIsAOpResult(value))
1522     owner = mlirOpResultGetOwner(value);
1523   if (mlirValueIsABlockArgument(value))
1524     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1525   if (mlirOperationIsNull(owner))
1526     throw py::error_already_set();
1527   MlirContext ctx = mlirOperationGetContext(owner);
1528   PyOperationRef ownerRef =
1529       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1530   return PyValue(ownerRef, value);
1531 }
1532 
1533 //------------------------------------------------------------------------------
1534 // PySymbolTable.
1535 //------------------------------------------------------------------------------
1536 
1537 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1538     : operation(operation.getOperation().getRef()) {
1539   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1540   if (mlirSymbolTableIsNull(symbolTable)) {
1541     throw py::cast_error("Operation is not a Symbol Table.");
1542   }
1543 }
1544 
1545 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1546   operation->checkValid();
1547   MlirOperation symbol = mlirSymbolTableLookup(
1548       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1549   if (mlirOperationIsNull(symbol))
1550     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1551 
1552   return PyOperation::forOperation(operation->getContext(), symbol,
1553                                    operation.getObject())
1554       ->createOpView();
1555 }
1556 
1557 void PySymbolTable::erase(PyOperationBase &symbol) {
1558   operation->checkValid();
1559   symbol.getOperation().checkValid();
1560   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1561   // The operation is also erased, so we must invalidate it. There may be Python
1562   // references to this operation so we don't want to delete it from the list of
1563   // live operations here.
1564   symbol.getOperation().valid = false;
1565 }
1566 
1567 void PySymbolTable::dunderDel(const std::string &name) {
1568   py::object operation = dunderGetItem(name);
1569   erase(py::cast<PyOperationBase &>(operation));
1570 }
1571 
1572 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1573   operation->checkValid();
1574   symbol.getOperation().checkValid();
1575   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1576       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1577   if (mlirAttributeIsNull(symbolAttr))
1578     throw py::value_error("Expected operation to have a symbol name.");
1579   return PyAttribute(
1580       symbol.getOperation().getContext(),
1581       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1582 }
1583 
1584 namespace {
1585 /// CRTP base class for Python MLIR values that subclass Value and should be
1586 /// castable from it. The value hierarchy is one level deep and is not supposed
1587 /// to accommodate other levels unless core MLIR changes.
1588 template <typename DerivedTy>
1589 class PyConcreteValue : public PyValue {
1590 public:
1591   // Derived classes must define statics for:
1592   //   IsAFunctionTy isaFunction
1593   //   const char *pyClassName
1594   // and redefine bindDerived.
1595   using ClassTy = py::class_<DerivedTy, PyValue>;
1596   using IsAFunctionTy = bool (*)(MlirValue);
1597 
1598   PyConcreteValue() = default;
1599   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1600       : PyValue(operationRef, value) {}
1601   PyConcreteValue(PyValue &orig)
1602       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1603 
1604   /// Attempts to cast the original value to the derived type and throws on
1605   /// type mismatches.
1606   static MlirValue castFrom(PyValue &orig) {
1607     if (!DerivedTy::isaFunction(orig.get())) {
1608       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1609       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1610                                              DerivedTy::pyClassName +
1611                                              " (from " + origRepr + ")");
1612     }
1613     return orig.get();
1614   }
1615 
1616   /// Binds the Python module objects to functions of this class.
1617   static void bind(py::module &m) {
1618     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1619     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1620     cls.def_static("isinstance", [](PyValue &otherValue) -> bool {
1621       return DerivedTy::isaFunction(otherValue);
1622     });
1623     DerivedTy::bindDerived(cls);
1624   }
1625 
1626   /// Implemented by derived classes to add methods to the Python subclass.
1627   static void bindDerived(ClassTy &m) {}
1628 };
1629 
1630 /// Python wrapper for MlirBlockArgument.
1631 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1632 public:
1633   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1634   static constexpr const char *pyClassName = "BlockArgument";
1635   using PyConcreteValue::PyConcreteValue;
1636 
1637   static void bindDerived(ClassTy &c) {
1638     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1639       return PyBlock(self.getParentOperation(),
1640                      mlirBlockArgumentGetOwner(self.get()));
1641     });
1642     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1643       return mlirBlockArgumentGetArgNumber(self.get());
1644     });
1645     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1646       return mlirBlockArgumentSetType(self.get(), type);
1647     });
1648   }
1649 };
1650 
1651 /// Python wrapper for MlirOpResult.
1652 class PyOpResult : public PyConcreteValue<PyOpResult> {
1653 public:
1654   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1655   static constexpr const char *pyClassName = "OpResult";
1656   using PyConcreteValue::PyConcreteValue;
1657 
1658   static void bindDerived(ClassTy &c) {
1659     c.def_property_readonly("owner", [](PyOpResult &self) {
1660       assert(
1661           mlirOperationEqual(self.getParentOperation()->get(),
1662                              mlirOpResultGetOwner(self.get())) &&
1663           "expected the owner of the value in Python to match that in the IR");
1664       return self.getParentOperation().getObject();
1665     });
1666     c.def_property_readonly("result_number", [](PyOpResult &self) {
1667       return mlirOpResultGetResultNumber(self.get());
1668     });
1669   }
1670 };
1671 
1672 /// Returns the list of types of the values held by container.
1673 template <typename Container>
1674 static std::vector<PyType> getValueTypes(Container &container,
1675                                          PyMlirContextRef &context) {
1676   std::vector<PyType> result;
1677   result.reserve(container.getNumElements());
1678   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1679     result.push_back(
1680         PyType(context, mlirValueGetType(container.getElement(i).get())));
1681   }
1682   return result;
1683 }
1684 
1685 /// A list of block arguments. Internally, these are stored as consecutive
1686 /// elements, random access is cheap. The argument list is associated with the
1687 /// operation that contains the block (detached blocks are not allowed in
1688 /// Python bindings) and extends its lifetime.
1689 class PyBlockArgumentList
1690     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1691 public:
1692   static constexpr const char *pyClassName = "BlockArgumentList";
1693 
1694   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1695                       intptr_t startIndex = 0, intptr_t length = -1,
1696                       intptr_t step = 1)
1697       : Sliceable(startIndex,
1698                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1699                   step),
1700         operation(std::move(operation)), block(block) {}
1701 
1702   /// Returns the number of arguments in the list.
1703   intptr_t getNumElements() {
1704     operation->checkValid();
1705     return mlirBlockGetNumArguments(block);
1706   }
1707 
1708   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1709   PyBlockArgument getElement(intptr_t pos) {
1710     MlirValue argument = mlirBlockGetArgument(block, pos);
1711     return PyBlockArgument(operation, argument);
1712   }
1713 
1714   /// Returns a sublist of this list.
1715   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1716                             intptr_t step) {
1717     return PyBlockArgumentList(operation, block, startIndex, length, step);
1718   }
1719 
1720   static void bindDerived(ClassTy &c) {
1721     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1722       return getValueTypes(self, self.operation->getContext());
1723     });
1724   }
1725 
1726 private:
1727   PyOperationRef operation;
1728   MlirBlock block;
1729 };
1730 
1731 /// A list of operation operands. Internally, these are stored as consecutive
1732 /// elements, random access is cheap. The result list is associated with the
1733 /// operation whose results these are, and extends the lifetime of this
1734 /// operation.
1735 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1736 public:
1737   static constexpr const char *pyClassName = "OpOperandList";
1738 
1739   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1740                   intptr_t length = -1, intptr_t step = 1)
1741       : Sliceable(startIndex,
1742                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1743                                : length,
1744                   step),
1745         operation(operation) {}
1746 
1747   intptr_t getNumElements() {
1748     operation->checkValid();
1749     return mlirOperationGetNumOperands(operation->get());
1750   }
1751 
1752   PyValue getElement(intptr_t pos) {
1753     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1754     MlirOperation owner;
1755     if (mlirValueIsAOpResult(operand))
1756       owner = mlirOpResultGetOwner(operand);
1757     else if (mlirValueIsABlockArgument(operand))
1758       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1759     else
1760       assert(false && "Value must be an block arg or op result.");
1761     PyOperationRef pyOwner =
1762         PyOperation::forOperation(operation->getContext(), owner);
1763     return PyValue(pyOwner, operand);
1764   }
1765 
1766   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1767     return PyOpOperandList(operation, startIndex, length, step);
1768   }
1769 
1770   void dunderSetItem(intptr_t index, PyValue value) {
1771     index = wrapIndex(index);
1772     mlirOperationSetOperand(operation->get(), index, value.get());
1773   }
1774 
1775   static void bindDerived(ClassTy &c) {
1776     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1777   }
1778 
1779 private:
1780   PyOperationRef operation;
1781 };
1782 
1783 /// A list of operation results. Internally, these are stored as consecutive
1784 /// elements, random access is cheap. The result list is associated with the
1785 /// operation whose results these are, and extends the lifetime of this
1786 /// operation.
1787 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1788 public:
1789   static constexpr const char *pyClassName = "OpResultList";
1790 
1791   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1792                  intptr_t length = -1, intptr_t step = 1)
1793       : Sliceable(startIndex,
1794                   length == -1 ? mlirOperationGetNumResults(operation->get())
1795                                : length,
1796                   step),
1797         operation(operation) {}
1798 
1799   intptr_t getNumElements() {
1800     operation->checkValid();
1801     return mlirOperationGetNumResults(operation->get());
1802   }
1803 
1804   PyOpResult getElement(intptr_t index) {
1805     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1806     return PyOpResult(value);
1807   }
1808 
1809   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1810     return PyOpResultList(operation, startIndex, length, step);
1811   }
1812 
1813   static void bindDerived(ClassTy &c) {
1814     c.def_property_readonly("types", [](PyOpResultList &self) {
1815       return getValueTypes(self, self.operation->getContext());
1816     });
1817   }
1818 
1819 private:
1820   PyOperationRef operation;
1821 };
1822 
1823 /// A list of operation attributes. Can be indexed by name, producing
1824 /// attributes, or by index, producing named attributes.
1825 class PyOpAttributeMap {
1826 public:
1827   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1828 
1829   PyAttribute dunderGetItemNamed(const std::string &name) {
1830     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1831                                                          toMlirStringRef(name));
1832     if (mlirAttributeIsNull(attr)) {
1833       throw SetPyError(PyExc_KeyError,
1834                        "attempt to access a non-existent attribute");
1835     }
1836     return PyAttribute(operation->getContext(), attr);
1837   }
1838 
1839   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1840     if (index < 0 || index >= dunderLen()) {
1841       throw SetPyError(PyExc_IndexError,
1842                        "attempt to access out of bounds attribute");
1843     }
1844     MlirNamedAttribute namedAttr =
1845         mlirOperationGetAttribute(operation->get(), index);
1846     return PyNamedAttribute(
1847         namedAttr.attribute,
1848         std::string(mlirIdentifierStr(namedAttr.name).data,
1849                     mlirIdentifierStr(namedAttr.name).length));
1850   }
1851 
1852   void dunderSetItem(const std::string &name, PyAttribute attr) {
1853     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1854                                     attr);
1855   }
1856 
1857   void dunderDelItem(const std::string &name) {
1858     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1859                                                      toMlirStringRef(name));
1860     if (!removed)
1861       throw SetPyError(PyExc_KeyError,
1862                        "attempt to delete a non-existent attribute");
1863   }
1864 
1865   intptr_t dunderLen() {
1866     return mlirOperationGetNumAttributes(operation->get());
1867   }
1868 
1869   bool dunderContains(const std::string &name) {
1870     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1871         operation->get(), toMlirStringRef(name)));
1872   }
1873 
1874   static void bind(py::module &m) {
1875     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
1876         .def("__contains__", &PyOpAttributeMap::dunderContains)
1877         .def("__len__", &PyOpAttributeMap::dunderLen)
1878         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1879         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1880         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1881         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1882   }
1883 
1884 private:
1885   PyOperationRef operation;
1886 };
1887 
1888 } // end namespace
1889 
1890 //------------------------------------------------------------------------------
1891 // Populates the core exports of the 'ir' submodule.
1892 //------------------------------------------------------------------------------
1893 
1894 void mlir::python::populateIRCore(py::module &m) {
1895   //----------------------------------------------------------------------------
1896   // Mapping of MlirContext.
1897   //----------------------------------------------------------------------------
1898   py::class_<PyMlirContext>(m, "Context", py::module_local())
1899       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1900       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1901       .def("_get_context_again",
1902            [](PyMlirContext &self) {
1903              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1904              return ref.releaseObject();
1905            })
1906       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1907       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1908       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1909                              &PyMlirContext::getCapsule)
1910       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1911       .def("__enter__", &PyMlirContext::contextEnter)
1912       .def("__exit__", &PyMlirContext::contextExit)
1913       .def_property_readonly_static(
1914           "current",
1915           [](py::object & /*class*/) {
1916             auto *context = PyThreadContextEntry::getDefaultContext();
1917             if (!context)
1918               throw SetPyError(PyExc_ValueError, "No current Context");
1919             return context;
1920           },
1921           "Gets the Context bound to the current thread or raises ValueError")
1922       .def_property_readonly(
1923           "dialects",
1924           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1925           "Gets a container for accessing dialects by name")
1926       .def_property_readonly(
1927           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1928           "Alias for 'dialect'")
1929       .def(
1930           "get_dialect_descriptor",
1931           [=](PyMlirContext &self, std::string &name) {
1932             MlirDialect dialect = mlirContextGetOrLoadDialect(
1933                 self.get(), {name.data(), name.size()});
1934             if (mlirDialectIsNull(dialect)) {
1935               throw SetPyError(PyExc_ValueError,
1936                                Twine("Dialect '") + name + "' not found");
1937             }
1938             return PyDialectDescriptor(self.getRef(), dialect);
1939           },
1940           "Gets or loads a dialect by name, returning its descriptor object")
1941       .def_property(
1942           "allow_unregistered_dialects",
1943           [](PyMlirContext &self) -> bool {
1944             return mlirContextGetAllowUnregisteredDialects(self.get());
1945           },
1946           [](PyMlirContext &self, bool value) {
1947             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1948           })
1949       .def("enable_multithreading",
1950            [](PyMlirContext &self, bool enable) {
1951              mlirContextEnableMultithreading(self.get(), enable);
1952            })
1953       .def("is_registered_operation",
1954            [](PyMlirContext &self, std::string &name) {
1955              return mlirContextIsRegisteredOperation(
1956                  self.get(), MlirStringRef{name.data(), name.size()});
1957            });
1958 
1959   //----------------------------------------------------------------------------
1960   // Mapping of PyDialectDescriptor
1961   //----------------------------------------------------------------------------
1962   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
1963       .def_property_readonly("namespace",
1964                              [](PyDialectDescriptor &self) {
1965                                MlirStringRef ns =
1966                                    mlirDialectGetNamespace(self.get());
1967                                return py::str(ns.data, ns.length);
1968                              })
1969       .def("__repr__", [](PyDialectDescriptor &self) {
1970         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1971         std::string repr("<DialectDescriptor ");
1972         repr.append(ns.data, ns.length);
1973         repr.append(">");
1974         return repr;
1975       });
1976 
1977   //----------------------------------------------------------------------------
1978   // Mapping of PyDialects
1979   //----------------------------------------------------------------------------
1980   py::class_<PyDialects>(m, "Dialects", py::module_local())
1981       .def("__getitem__",
1982            [=](PyDialects &self, std::string keyName) {
1983              MlirDialect dialect =
1984                  self.getDialectForKey(keyName, /*attrError=*/false);
1985              py::object descriptor =
1986                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1987              return createCustomDialectWrapper(keyName, std::move(descriptor));
1988            })
1989       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1990         MlirDialect dialect =
1991             self.getDialectForKey(attrName, /*attrError=*/true);
1992         py::object descriptor =
1993             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1994         return createCustomDialectWrapper(attrName, std::move(descriptor));
1995       });
1996 
1997   //----------------------------------------------------------------------------
1998   // Mapping of PyDialect
1999   //----------------------------------------------------------------------------
2000   py::class_<PyDialect>(m, "Dialect", py::module_local())
2001       .def(py::init<py::object>(), "descriptor")
2002       .def_property_readonly(
2003           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2004       .def("__repr__", [](py::object self) {
2005         auto clazz = self.attr("__class__");
2006         return py::str("<Dialect ") +
2007                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2008                clazz.attr("__module__") + py::str(".") +
2009                clazz.attr("__name__") + py::str(")>");
2010       });
2011 
2012   //----------------------------------------------------------------------------
2013   // Mapping of Location
2014   //----------------------------------------------------------------------------
2015   py::class_<PyLocation>(m, "Location", py::module_local())
2016       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2017       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2018       .def("__enter__", &PyLocation::contextEnter)
2019       .def("__exit__", &PyLocation::contextExit)
2020       .def("__eq__",
2021            [](PyLocation &self, PyLocation &other) -> bool {
2022              return mlirLocationEqual(self, other);
2023            })
2024       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2025       .def_property_readonly_static(
2026           "current",
2027           [](py::object & /*class*/) {
2028             auto *loc = PyThreadContextEntry::getDefaultLocation();
2029             if (!loc)
2030               throw SetPyError(PyExc_ValueError, "No current Location");
2031             return loc;
2032           },
2033           "Gets the Location bound to the current thread or raises ValueError")
2034       .def_static(
2035           "unknown",
2036           [](DefaultingPyMlirContext context) {
2037             return PyLocation(context->getRef(),
2038                               mlirLocationUnknownGet(context->get()));
2039           },
2040           py::arg("context") = py::none(),
2041           "Gets a Location representing an unknown location")
2042       .def_static(
2043           "callsite",
2044           [](PyLocation callee, const std::vector<PyLocation> &frames,
2045              DefaultingPyMlirContext context) {
2046             if (frames.empty())
2047               throw py::value_error("No caller frames provided");
2048             MlirLocation caller = frames.back().get();
2049             for (const PyLocation &frame :
2050                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2051               caller = mlirLocationCallSiteGet(frame.get(), caller);
2052             return PyLocation(context->getRef(),
2053                               mlirLocationCallSiteGet(callee.get(), caller));
2054           },
2055           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2056           kContextGetCallSiteLocationDocstring)
2057       .def_static(
2058           "file",
2059           [](std::string filename, int line, int col,
2060              DefaultingPyMlirContext context) {
2061             return PyLocation(
2062                 context->getRef(),
2063                 mlirLocationFileLineColGet(
2064                     context->get(), toMlirStringRef(filename), line, col));
2065           },
2066           py::arg("filename"), py::arg("line"), py::arg("col"),
2067           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2068       .def_static(
2069           "name",
2070           [](std::string name, llvm::Optional<PyLocation> childLoc,
2071              DefaultingPyMlirContext context) {
2072             return PyLocation(
2073                 context->getRef(),
2074                 mlirLocationNameGet(
2075                     context->get(), toMlirStringRef(name),
2076                     childLoc ? childLoc->get()
2077                              : mlirLocationUnknownGet(context->get())));
2078           },
2079           py::arg("name"), py::arg("childLoc") = py::none(),
2080           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2081       .def_property_readonly(
2082           "context",
2083           [](PyLocation &self) { return self.getContext().getObject(); },
2084           "Context that owns the Location")
2085       .def("__repr__", [](PyLocation &self) {
2086         PyPrintAccumulator printAccum;
2087         mlirLocationPrint(self, printAccum.getCallback(),
2088                           printAccum.getUserData());
2089         return printAccum.join();
2090       });
2091 
2092   //----------------------------------------------------------------------------
2093   // Mapping of Module
2094   //----------------------------------------------------------------------------
2095   py::class_<PyModule>(m, "Module", py::module_local())
2096       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2097       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2098       .def_static(
2099           "parse",
2100           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2101             MlirModule module = mlirModuleCreateParse(
2102                 context->get(), toMlirStringRef(moduleAsm));
2103             // TODO: Rework error reporting once diagnostic engine is exposed
2104             // in C API.
2105             if (mlirModuleIsNull(module)) {
2106               throw SetPyError(
2107                   PyExc_ValueError,
2108                   "Unable to parse module assembly (see diagnostics)");
2109             }
2110             return PyModule::forModule(module).releaseObject();
2111           },
2112           py::arg("asm"), py::arg("context") = py::none(),
2113           kModuleParseDocstring)
2114       .def_static(
2115           "create",
2116           [](DefaultingPyLocation loc) {
2117             MlirModule module = mlirModuleCreateEmpty(loc);
2118             return PyModule::forModule(module).releaseObject();
2119           },
2120           py::arg("loc") = py::none(), "Creates an empty module")
2121       .def_property_readonly(
2122           "context",
2123           [](PyModule &self) { return self.getContext().getObject(); },
2124           "Context that created the Module")
2125       .def_property_readonly(
2126           "operation",
2127           [](PyModule &self) {
2128             return PyOperation::forOperation(self.getContext(),
2129                                              mlirModuleGetOperation(self.get()),
2130                                              self.getRef().releaseObject())
2131                 .releaseObject();
2132           },
2133           "Accesses the module as an operation")
2134       .def_property_readonly(
2135           "body",
2136           [](PyModule &self) {
2137             PyOperationRef module_op = PyOperation::forOperation(
2138                 self.getContext(), mlirModuleGetOperation(self.get()),
2139                 self.getRef().releaseObject());
2140             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2141             return returnBlock;
2142           },
2143           "Return the block for this module")
2144       .def(
2145           "dump",
2146           [](PyModule &self) {
2147             mlirOperationDump(mlirModuleGetOperation(self.get()));
2148           },
2149           kDumpDocstring)
2150       .def(
2151           "__str__",
2152           [](PyModule &self) {
2153             MlirOperation operation = mlirModuleGetOperation(self.get());
2154             PyPrintAccumulator printAccum;
2155             mlirOperationPrint(operation, printAccum.getCallback(),
2156                                printAccum.getUserData());
2157             return printAccum.join();
2158           },
2159           kOperationStrDunderDocstring);
2160 
2161   //----------------------------------------------------------------------------
2162   // Mapping of Operation.
2163   //----------------------------------------------------------------------------
2164   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2165       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2166                              [](PyOperationBase &self) {
2167                                return self.getOperation().getCapsule();
2168                              })
2169       .def("__eq__",
2170            [](PyOperationBase &self, PyOperationBase &other) {
2171              return &self.getOperation() == &other.getOperation();
2172            })
2173       .def("__eq__",
2174            [](PyOperationBase &self, py::object other) { return false; })
2175       .def("__hash__",
2176            [](PyOperationBase &self) {
2177              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2178            })
2179       .def_property_readonly("attributes",
2180                              [](PyOperationBase &self) {
2181                                return PyOpAttributeMap(
2182                                    self.getOperation().getRef());
2183                              })
2184       .def_property_readonly("operands",
2185                              [](PyOperationBase &self) {
2186                                return PyOpOperandList(
2187                                    self.getOperation().getRef());
2188                              })
2189       .def_property_readonly("regions",
2190                              [](PyOperationBase &self) {
2191                                return PyRegionList(
2192                                    self.getOperation().getRef());
2193                              })
2194       .def_property_readonly(
2195           "results",
2196           [](PyOperationBase &self) {
2197             return PyOpResultList(self.getOperation().getRef());
2198           },
2199           "Returns the list of Operation results.")
2200       .def_property_readonly(
2201           "result",
2202           [](PyOperationBase &self) {
2203             auto &operation = self.getOperation();
2204             auto numResults = mlirOperationGetNumResults(operation);
2205             if (numResults != 1) {
2206               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2207               throw SetPyError(
2208                   PyExc_ValueError,
2209                   Twine("Cannot call .result on operation ") +
2210                       StringRef(name.data, name.length) + " which has " +
2211                       Twine(numResults) +
2212                       " results (it is only valid for operations with a "
2213                       "single result)");
2214             }
2215             return PyOpResult(operation.getRef(),
2216                               mlirOperationGetResult(operation, 0));
2217           },
2218           "Shortcut to get an op result if it has only one (throws an error "
2219           "otherwise).")
2220       .def_property_readonly(
2221           "location",
2222           [](PyOperationBase &self) {
2223             PyOperation &operation = self.getOperation();
2224             return PyLocation(operation.getContext(),
2225                               mlirOperationGetLocation(operation.get()));
2226           },
2227           "Returns the source location the operation was defined or derived "
2228           "from.")
2229       .def(
2230           "__str__",
2231           [](PyOperationBase &self) {
2232             return self.getAsm(/*binary=*/false,
2233                                /*largeElementsLimit=*/llvm::None,
2234                                /*enableDebugInfo=*/false,
2235                                /*prettyDebugInfo=*/false,
2236                                /*printGenericOpForm=*/false,
2237                                /*useLocalScope=*/false);
2238           },
2239           "Returns the assembly form of the operation.")
2240       .def("print", &PyOperationBase::print,
2241            // Careful: Lots of arguments must match up with print method.
2242            py::arg("file") = py::none(), py::arg("binary") = false,
2243            py::arg("large_elements_limit") = py::none(),
2244            py::arg("enable_debug_info") = false,
2245            py::arg("pretty_debug_info") = false,
2246            py::arg("print_generic_op_form") = false,
2247            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2248       .def("get_asm", &PyOperationBase::getAsm,
2249            // Careful: Lots of arguments must match up with get_asm method.
2250            py::arg("binary") = false,
2251            py::arg("large_elements_limit") = py::none(),
2252            py::arg("enable_debug_info") = false,
2253            py::arg("pretty_debug_info") = false,
2254            py::arg("print_generic_op_form") = false,
2255            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2256       .def(
2257           "verify",
2258           [](PyOperationBase &self) {
2259             return mlirOperationVerify(self.getOperation());
2260           },
2261           "Verify the operation and return true if it passes, false if it "
2262           "fails.")
2263       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2264            "Puts self immediately after the other operation in its parent "
2265            "block.")
2266       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2267            "Puts self immediately before the other operation in its parent "
2268            "block.")
2269       .def(
2270           "detach_from_parent",
2271           [](PyOperationBase &self) {
2272             PyOperation &operation = self.getOperation();
2273             operation.checkValid();
2274             if (!operation.isAttached())
2275               throw py::value_error("Detached operation has no parent.");
2276 
2277             operation.detachFromParent();
2278             return operation.createOpView();
2279           },
2280           "Detaches the operation from its parent block.");
2281 
2282   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2283       .def_static("create", &PyOperation::create, py::arg("name"),
2284                   py::arg("results") = py::none(),
2285                   py::arg("operands") = py::none(),
2286                   py::arg("attributes") = py::none(),
2287                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2288                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2289                   kOperationCreateDocstring)
2290       .def_property_readonly("parent",
2291                              [](PyOperation &self) -> py::object {
2292                                auto parent = self.getParentOperation();
2293                                if (parent)
2294                                  return parent->getObject();
2295                                return py::none();
2296                              })
2297       .def("erase", &PyOperation::erase)
2298       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2299                              &PyOperation::getCapsule)
2300       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2301       .def_property_readonly("name",
2302                              [](PyOperation &self) {
2303                                self.checkValid();
2304                                MlirOperation operation = self.get();
2305                                MlirStringRef name = mlirIdentifierStr(
2306                                    mlirOperationGetName(operation));
2307                                return py::str(name.data, name.length);
2308                              })
2309       .def_property_readonly(
2310           "context",
2311           [](PyOperation &self) {
2312             self.checkValid();
2313             return self.getContext().getObject();
2314           },
2315           "Context that owns the Operation")
2316       .def_property_readonly("opview", &PyOperation::createOpView);
2317 
2318   auto opViewClass =
2319       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2320           .def(py::init<py::object>())
2321           .def_property_readonly("operation", &PyOpView::getOperationObject)
2322           .def_property_readonly(
2323               "context",
2324               [](PyOpView &self) {
2325                 return self.getOperation().getContext().getObject();
2326               },
2327               "Context that owns the Operation")
2328           .def("__str__", [](PyOpView &self) {
2329             return py::str(self.getOperationObject());
2330           });
2331   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2332   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2333   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2334   opViewClass.attr("build_generic") = classmethod(
2335       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2336       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2337       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2338       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2339       "Builds a specific, generated OpView based on class level attributes.");
2340 
2341   //----------------------------------------------------------------------------
2342   // Mapping of PyRegion.
2343   //----------------------------------------------------------------------------
2344   py::class_<PyRegion>(m, "Region", py::module_local())
2345       .def_property_readonly(
2346           "blocks",
2347           [](PyRegion &self) {
2348             return PyBlockList(self.getParentOperation(), self.get());
2349           },
2350           "Returns a forward-optimized sequence of blocks.")
2351       .def_property_readonly(
2352           "owner",
2353           [](PyRegion &self) {
2354             return self.getParentOperation()->createOpView();
2355           },
2356           "Returns the operation owning this region.")
2357       .def(
2358           "__iter__",
2359           [](PyRegion &self) {
2360             self.checkValid();
2361             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2362             return PyBlockIterator(self.getParentOperation(), firstBlock);
2363           },
2364           "Iterates over blocks in the region.")
2365       .def("__eq__",
2366            [](PyRegion &self, PyRegion &other) {
2367              return self.get().ptr == other.get().ptr;
2368            })
2369       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2370 
2371   //----------------------------------------------------------------------------
2372   // Mapping of PyBlock.
2373   //----------------------------------------------------------------------------
2374   py::class_<PyBlock>(m, "Block", py::module_local())
2375       .def_property_readonly(
2376           "owner",
2377           [](PyBlock &self) {
2378             return self.getParentOperation()->createOpView();
2379           },
2380           "Returns the owning operation of this block.")
2381       .def_property_readonly(
2382           "region",
2383           [](PyBlock &self) {
2384             MlirRegion region = mlirBlockGetParentRegion(self.get());
2385             return PyRegion(self.getParentOperation(), region);
2386           },
2387           "Returns the owning region of this block.")
2388       .def_property_readonly(
2389           "arguments",
2390           [](PyBlock &self) {
2391             return PyBlockArgumentList(self.getParentOperation(), self.get());
2392           },
2393           "Returns a list of block arguments.")
2394       .def_property_readonly(
2395           "operations",
2396           [](PyBlock &self) {
2397             return PyOperationList(self.getParentOperation(), self.get());
2398           },
2399           "Returns a forward-optimized sequence of operations.")
2400       .def_static(
2401           "create_at_start",
2402           [](PyRegion &parent, py::list pyArgTypes) {
2403             parent.checkValid();
2404             llvm::SmallVector<MlirType, 4> argTypes;
2405             argTypes.reserve(pyArgTypes.size());
2406             for (auto &pyArg : pyArgTypes) {
2407               argTypes.push_back(pyArg.cast<PyType &>());
2408             }
2409 
2410             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2411             mlirRegionInsertOwnedBlock(parent, 0, block);
2412             return PyBlock(parent.getParentOperation(), block);
2413           },
2414           py::arg("parent"), py::arg("pyArgTypes") = py::list(),
2415           "Creates and returns a new Block at the beginning of the given "
2416           "region (with given argument types).")
2417       .def(
2418           "create_before",
2419           [](PyBlock &self, py::args pyArgTypes) {
2420             self.checkValid();
2421             llvm::SmallVector<MlirType, 4> argTypes;
2422             argTypes.reserve(pyArgTypes.size());
2423             for (auto &pyArg : pyArgTypes) {
2424               argTypes.push_back(pyArg.cast<PyType &>());
2425             }
2426 
2427             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2428             MlirRegion region = mlirBlockGetParentRegion(self.get());
2429             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2430             return PyBlock(self.getParentOperation(), block);
2431           },
2432           "Creates and returns a new Block before this block "
2433           "(with given argument types).")
2434       .def(
2435           "create_after",
2436           [](PyBlock &self, py::args pyArgTypes) {
2437             self.checkValid();
2438             llvm::SmallVector<MlirType, 4> argTypes;
2439             argTypes.reserve(pyArgTypes.size());
2440             for (auto &pyArg : pyArgTypes) {
2441               argTypes.push_back(pyArg.cast<PyType &>());
2442             }
2443 
2444             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2445             MlirRegion region = mlirBlockGetParentRegion(self.get());
2446             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2447             return PyBlock(self.getParentOperation(), block);
2448           },
2449           "Creates and returns a new Block after this block "
2450           "(with given argument types).")
2451       .def(
2452           "__iter__",
2453           [](PyBlock &self) {
2454             self.checkValid();
2455             MlirOperation firstOperation =
2456                 mlirBlockGetFirstOperation(self.get());
2457             return PyOperationIterator(self.getParentOperation(),
2458                                        firstOperation);
2459           },
2460           "Iterates over operations in the block.")
2461       .def("__eq__",
2462            [](PyBlock &self, PyBlock &other) {
2463              return self.get().ptr == other.get().ptr;
2464            })
2465       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2466       .def(
2467           "__str__",
2468           [](PyBlock &self) {
2469             self.checkValid();
2470             PyPrintAccumulator printAccum;
2471             mlirBlockPrint(self.get(), printAccum.getCallback(),
2472                            printAccum.getUserData());
2473             return printAccum.join();
2474           },
2475           "Returns the assembly form of the block.")
2476       .def(
2477           "append",
2478           [](PyBlock &self, PyOperationBase &operation) {
2479             if (operation.getOperation().isAttached())
2480               operation.getOperation().detachFromParent();
2481 
2482             MlirOperation mlirOperation = operation.getOperation().get();
2483             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2484             operation.getOperation().setAttached(
2485                 self.getParentOperation().getObject());
2486           },
2487           "Appends an operation to this block. If the operation is currently "
2488           "in another block, it will be moved.");
2489 
2490   //----------------------------------------------------------------------------
2491   // Mapping of PyInsertionPoint.
2492   //----------------------------------------------------------------------------
2493 
2494   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2495       .def(py::init<PyBlock &>(), py::arg("block"),
2496            "Inserts after the last operation but still inside the block.")
2497       .def("__enter__", &PyInsertionPoint::contextEnter)
2498       .def("__exit__", &PyInsertionPoint::contextExit)
2499       .def_property_readonly_static(
2500           "current",
2501           [](py::object & /*class*/) {
2502             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2503             if (!ip)
2504               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2505             return ip;
2506           },
2507           "Gets the InsertionPoint bound to the current thread or raises "
2508           "ValueError if none has been set")
2509       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2510            "Inserts before a referenced operation.")
2511       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2512                   py::arg("block"), "Inserts at the beginning of the block.")
2513       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2514                   py::arg("block"), "Inserts before the block terminator.")
2515       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2516            "Inserts an operation.")
2517       .def_property_readonly(
2518           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2519           "Returns the block that this InsertionPoint points to.");
2520 
2521   //----------------------------------------------------------------------------
2522   // Mapping of PyAttribute.
2523   //----------------------------------------------------------------------------
2524   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2525       // Delegate to the PyAttribute copy constructor, which will also lifetime
2526       // extend the backing context which owns the MlirAttribute.
2527       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2528            "Casts the passed attribute to the generic Attribute")
2529       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2530                              &PyAttribute::getCapsule)
2531       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2532       .def_static(
2533           "parse",
2534           [](std::string attrSpec, DefaultingPyMlirContext context) {
2535             MlirAttribute type = mlirAttributeParseGet(
2536                 context->get(), toMlirStringRef(attrSpec));
2537             // TODO: Rework error reporting once diagnostic engine is exposed
2538             // in C API.
2539             if (mlirAttributeIsNull(type)) {
2540               throw SetPyError(PyExc_ValueError,
2541                                Twine("Unable to parse attribute: '") +
2542                                    attrSpec + "'");
2543             }
2544             return PyAttribute(context->getRef(), type);
2545           },
2546           py::arg("asm"), py::arg("context") = py::none(),
2547           "Parses an attribute from an assembly form")
2548       .def_property_readonly(
2549           "context",
2550           [](PyAttribute &self) { return self.getContext().getObject(); },
2551           "Context that owns the Attribute")
2552       .def_property_readonly("type",
2553                              [](PyAttribute &self) {
2554                                return PyType(self.getContext()->getRef(),
2555                                              mlirAttributeGetType(self));
2556                              })
2557       .def(
2558           "get_named",
2559           [](PyAttribute &self, std::string name) {
2560             return PyNamedAttribute(self, std::move(name));
2561           },
2562           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2563       .def("__eq__",
2564            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2565       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2566       .def("__hash__",
2567            [](PyAttribute &self) {
2568              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2569            })
2570       .def(
2571           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2572           kDumpDocstring)
2573       .def(
2574           "__str__",
2575           [](PyAttribute &self) {
2576             PyPrintAccumulator printAccum;
2577             mlirAttributePrint(self, printAccum.getCallback(),
2578                                printAccum.getUserData());
2579             return printAccum.join();
2580           },
2581           "Returns the assembly form of the Attribute.")
2582       .def("__repr__", [](PyAttribute &self) {
2583         // Generally, assembly formats are not printed for __repr__ because
2584         // this can cause exceptionally long debug output and exceptions.
2585         // However, attribute values are generally considered useful and are
2586         // printed. This may need to be re-evaluated if debug dumps end up
2587         // being excessive.
2588         PyPrintAccumulator printAccum;
2589         printAccum.parts.append("Attribute(");
2590         mlirAttributePrint(self, printAccum.getCallback(),
2591                            printAccum.getUserData());
2592         printAccum.parts.append(")");
2593         return printAccum.join();
2594       });
2595 
2596   //----------------------------------------------------------------------------
2597   // Mapping of PyNamedAttribute
2598   //----------------------------------------------------------------------------
2599   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2600       .def("__repr__",
2601            [](PyNamedAttribute &self) {
2602              PyPrintAccumulator printAccum;
2603              printAccum.parts.append("NamedAttribute(");
2604              printAccum.parts.append(
2605                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
2606                          mlirIdentifierStr(self.namedAttr.name).length));
2607              printAccum.parts.append("=");
2608              mlirAttributePrint(self.namedAttr.attribute,
2609                                 printAccum.getCallback(),
2610                                 printAccum.getUserData());
2611              printAccum.parts.append(")");
2612              return printAccum.join();
2613            })
2614       .def_property_readonly(
2615           "name",
2616           [](PyNamedAttribute &self) {
2617             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2618                            mlirIdentifierStr(self.namedAttr.name).length);
2619           },
2620           "The name of the NamedAttribute binding")
2621       .def_property_readonly(
2622           "attr",
2623           [](PyNamedAttribute &self) {
2624             // TODO: When named attribute is removed/refactored, also remove
2625             // this constructor (it does an inefficient table lookup).
2626             auto contextRef = PyMlirContext::forContext(
2627                 mlirAttributeGetContext(self.namedAttr.attribute));
2628             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2629           },
2630           py::keep_alive<0, 1>(),
2631           "The underlying generic attribute of the NamedAttribute binding");
2632 
2633   //----------------------------------------------------------------------------
2634   // Mapping of PyType.
2635   //----------------------------------------------------------------------------
2636   py::class_<PyType>(m, "Type", py::module_local())
2637       // Delegate to the PyType copy constructor, which will also lifetime
2638       // extend the backing context which owns the MlirType.
2639       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2640            "Casts the passed type to the generic Type")
2641       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2642       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2643       .def_static(
2644           "parse",
2645           [](std::string typeSpec, DefaultingPyMlirContext context) {
2646             MlirType type =
2647                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2648             // TODO: Rework error reporting once diagnostic engine is exposed
2649             // in C API.
2650             if (mlirTypeIsNull(type)) {
2651               throw SetPyError(PyExc_ValueError,
2652                                Twine("Unable to parse type: '") + typeSpec +
2653                                    "'");
2654             }
2655             return PyType(context->getRef(), type);
2656           },
2657           py::arg("asm"), py::arg("context") = py::none(),
2658           kContextParseTypeDocstring)
2659       .def_property_readonly(
2660           "context", [](PyType &self) { return self.getContext().getObject(); },
2661           "Context that owns the Type")
2662       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2663       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2664       .def("__hash__",
2665            [](PyType &self) {
2666              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2667            })
2668       .def(
2669           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2670       .def(
2671           "__str__",
2672           [](PyType &self) {
2673             PyPrintAccumulator printAccum;
2674             mlirTypePrint(self, printAccum.getCallback(),
2675                           printAccum.getUserData());
2676             return printAccum.join();
2677           },
2678           "Returns the assembly form of the type.")
2679       .def("__repr__", [](PyType &self) {
2680         // Generally, assembly formats are not printed for __repr__ because
2681         // this can cause exceptionally long debug output and exceptions.
2682         // However, types are an exception as they typically have compact
2683         // assembly forms and printing them is useful.
2684         PyPrintAccumulator printAccum;
2685         printAccum.parts.append("Type(");
2686         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2687         printAccum.parts.append(")");
2688         return printAccum.join();
2689       });
2690 
2691   //----------------------------------------------------------------------------
2692   // Mapping of Value.
2693   //----------------------------------------------------------------------------
2694   py::class_<PyValue>(m, "Value", py::module_local())
2695       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2696       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2697       .def_property_readonly(
2698           "context",
2699           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2700           "Context in which the value lives.")
2701       .def(
2702           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2703           kDumpDocstring)
2704       .def_property_readonly(
2705           "owner",
2706           [](PyValue &self) {
2707             assert(mlirOperationEqual(self.getParentOperation()->get(),
2708                                       mlirOpResultGetOwner(self.get())) &&
2709                    "expected the owner of the value in Python to match that in "
2710                    "the IR");
2711             return self.getParentOperation().getObject();
2712           })
2713       .def("__eq__",
2714            [](PyValue &self, PyValue &other) {
2715              return self.get().ptr == other.get().ptr;
2716            })
2717       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2718       .def("__hash__",
2719            [](PyValue &self) {
2720              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2721            })
2722       .def(
2723           "__str__",
2724           [](PyValue &self) {
2725             PyPrintAccumulator printAccum;
2726             printAccum.parts.append("Value(");
2727             mlirValuePrint(self.get(), printAccum.getCallback(),
2728                            printAccum.getUserData());
2729             printAccum.parts.append(")");
2730             return printAccum.join();
2731           },
2732           kValueDunderStrDocstring)
2733       .def_property_readonly("type", [](PyValue &self) {
2734         return PyType(self.getParentOperation()->getContext(),
2735                       mlirValueGetType(self.get()));
2736       });
2737   PyBlockArgument::bind(m);
2738   PyOpResult::bind(m);
2739 
2740   //----------------------------------------------------------------------------
2741   // Mapping of SymbolTable.
2742   //----------------------------------------------------------------------------
2743   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
2744       .def(py::init<PyOperationBase &>())
2745       .def("__getitem__", &PySymbolTable::dunderGetItem)
2746       .def("insert", &PySymbolTable::insert)
2747       .def("erase", &PySymbolTable::erase)
2748       .def("__delitem__", &PySymbolTable::dunderDel)
2749       .def("__contains__", [](PySymbolTable &table, const std::string &name) {
2750         return !mlirOperationIsNull(mlirSymbolTableLookup(
2751             table, mlirStringRefCreate(name.data(), name.length())));
2752       });
2753 
  // Container bindings.
  // Each bind(m) call presumably registers the Python class for one of the
  // helper list/iterator/map types handed out by the bindings above, e.g.
  // PyBlockArgumentList (Block.arguments), PyBlockList (Region.blocks),
  // PyBlockIterator (iter(Region)), PyOperationList (Block.operations),
  // PyOperationIterator (iter(Block)), and PyOpAttributeMap /
  // PyOpOperandList / PyOpResultList / PyRegionList (the operation
  // attributes / operands / results / regions properties).
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  // NOTE(review): PyGlobalDebugFlag is defined elsewhere; presumably it exposes
  // MLIR's global debug flag (mlir-c/Debug.h is included by this file) —
  // confirm against its bind() implementation.
  PyGlobalDebugFlag::bind(m);
2768 }
2769