1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include <pybind11/stl.h>
23 
24 namespace py = pybind11;
25 using namespace mlir;
26 using namespace mlir::python;
27 
28 using llvm::SmallVector;
29 using llvm::StringRef;
30 using llvm::Twine;
31 
32 //------------------------------------------------------------------------------
33 // Docstrings (trivial, non-duplicated docstrings are included inline).
34 //------------------------------------------------------------------------------
35 
// Docstring text describing how a type is parsed from its assembly form.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Docstring text for obtaining a call-site Location.
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

// Docstring text for obtaining a file/line/column Location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Docstring text for obtaining a named Location.
// NOTE(review): this identifier is spelled "DocString" while its siblings use
// "Docstring"; renaming would need every user of the constant updated.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

// Docstring text describing how a module is parsed from a string.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring text describing the arguments accepted when creating an operation.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
78 
// Docstring text describing the arguments for printing an operation to a
// file-like object. Argument names are spelled in lower snake_case, matching
// the rest of the documented keyword arguments.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";
97 
// Docstring text for retrieving an operation's assembly as bytes or str.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Docstring text for the default-options string form of an operation.
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Docstring text shared by debug-dump entry points.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring text for appending a block (used by PyBlockList::bind below).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Docstring text describing the string form of a value.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
136 
137 //------------------------------------------------------------------------------
138 // Utilities.
139 //------------------------------------------------------------------------------
140 
141 /// Helper for creating an @classmethod.
142 template <class Func, typename... Args>
143 py::object classmethod(Func f, Args... args) {
144   py::object cf = py::cpp_function(f, args...);
145   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
146 }
147 
148 static py::object
149 createCustomDialectWrapper(const std::string &dialectNamespace,
150                            py::object dialectDescriptor) {
151   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
152   if (!dialectClass) {
153     // Use the base class.
154     return py::cast(PyDialect(std::move(dialectDescriptor)));
155   }
156 
157   // Create the custom implementation.
158   return (*dialectClass)(std::move(dialectDescriptor));
159 }
160 
161 static MlirStringRef toMlirStringRef(const std::string &s) {
162   return mlirStringRefCreate(s.data(), s.size());
163 }
164 
165 /// Wrapper for the global LLVM debugging flag.
166 struct PyGlobalDebugFlag {
167   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
168 
169   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
170 
171   static void bind(py::module &m) {
172     // Debug flags.
173     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
174         .def_property_static("flag", &PyGlobalDebugFlag::get,
175                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
176   }
177 };
178 
179 //------------------------------------------------------------------------------
180 // Collections.
181 //------------------------------------------------------------------------------
182 
183 namespace {
184 
185 class PyRegionIterator {
186 public:
187   PyRegionIterator(PyOperationRef operation)
188       : operation(std::move(operation)) {}
189 
190   PyRegionIterator &dunderIter() { return *this; }
191 
192   PyRegion dunderNext() {
193     operation->checkValid();
194     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
195       throw py::stop_iteration();
196     }
197     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
198     return PyRegion(operation, region);
199   }
200 
201   static void bind(py::module &m) {
202     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
203         .def("__iter__", &PyRegionIterator::dunderIter)
204         .def("__next__", &PyRegionIterator::dunderNext);
205   }
206 
207 private:
208   PyOperationRef operation;
209   int nextIndex = 0;
210 };
211 
212 /// Regions of an op are fixed length and indexed numerically so are represented
213 /// with a sequence-like container.
214 class PyRegionList {
215 public:
216   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
217 
218   intptr_t dunderLen() {
219     operation->checkValid();
220     return mlirOperationGetNumRegions(operation->get());
221   }
222 
223   PyRegion dunderGetItem(intptr_t index) {
224     // dunderLen checks validity.
225     if (index < 0 || index >= dunderLen()) {
226       throw SetPyError(PyExc_IndexError,
227                        "attempt to access out of bounds region");
228     }
229     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
230     return PyRegion(operation, region);
231   }
232 
233   static void bind(py::module &m) {
234     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
235         .def("__len__", &PyRegionList::dunderLen)
236         .def("__getitem__", &PyRegionList::dunderGetItem);
237   }
238 
239 private:
240   PyOperationRef operation;
241 };
242 
243 class PyBlockIterator {
244 public:
245   PyBlockIterator(PyOperationRef operation, MlirBlock next)
246       : operation(std::move(operation)), next(next) {}
247 
248   PyBlockIterator &dunderIter() { return *this; }
249 
250   PyBlock dunderNext() {
251     operation->checkValid();
252     if (mlirBlockIsNull(next)) {
253       throw py::stop_iteration();
254     }
255 
256     PyBlock returnBlock(operation, next);
257     next = mlirBlockGetNextInRegion(next);
258     return returnBlock;
259   }
260 
261   static void bind(py::module &m) {
262     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
263         .def("__iter__", &PyBlockIterator::dunderIter)
264         .def("__next__", &PyBlockIterator::dunderNext);
265   }
266 
267 private:
268   PyOperationRef operation;
269   MlirBlock next;
270 };
271 
272 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
273 /// we present them as a more full-featured list-like container but optimize
274 /// it for forward iteration. Blocks are always owned by a region.
275 class PyBlockList {
276 public:
277   PyBlockList(PyOperationRef operation, MlirRegion region)
278       : operation(std::move(operation)), region(region) {}
279 
280   PyBlockIterator dunderIter() {
281     operation->checkValid();
282     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
283   }
284 
285   intptr_t dunderLen() {
286     operation->checkValid();
287     intptr_t count = 0;
288     MlirBlock block = mlirRegionGetFirstBlock(region);
289     while (!mlirBlockIsNull(block)) {
290       count += 1;
291       block = mlirBlockGetNextInRegion(block);
292     }
293     return count;
294   }
295 
296   PyBlock dunderGetItem(intptr_t index) {
297     operation->checkValid();
298     if (index < 0) {
299       throw SetPyError(PyExc_IndexError,
300                        "attempt to access out of bounds block");
301     }
302     MlirBlock block = mlirRegionGetFirstBlock(region);
303     while (!mlirBlockIsNull(block)) {
304       if (index == 0) {
305         return PyBlock(operation, block);
306       }
307       block = mlirBlockGetNextInRegion(block);
308       index -= 1;
309     }
310     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
311   }
312 
313   PyBlock appendBlock(py::args pyArgTypes) {
314     operation->checkValid();
315     llvm::SmallVector<MlirType, 4> argTypes;
316     argTypes.reserve(pyArgTypes.size());
317     for (auto &pyArg : pyArgTypes) {
318       argTypes.push_back(pyArg.cast<PyType &>());
319     }
320 
321     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
322     mlirRegionAppendOwnedBlock(region, block);
323     return PyBlock(operation, block);
324   }
325 
326   static void bind(py::module &m) {
327     py::class_<PyBlockList>(m, "BlockList", py::module_local())
328         .def("__getitem__", &PyBlockList::dunderGetItem)
329         .def("__iter__", &PyBlockList::dunderIter)
330         .def("__len__", &PyBlockList::dunderLen)
331         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
332   }
333 
334 private:
335   PyOperationRef operation;
336   MlirRegion region;
337 };
338 
339 class PyOperationIterator {
340 public:
341   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
342       : parentOperation(std::move(parentOperation)), next(next) {}
343 
344   PyOperationIterator &dunderIter() { return *this; }
345 
346   py::object dunderNext() {
347     parentOperation->checkValid();
348     if (mlirOperationIsNull(next)) {
349       throw py::stop_iteration();
350     }
351 
352     PyOperationRef returnOperation =
353         PyOperation::forOperation(parentOperation->getContext(), next);
354     next = mlirOperationGetNextInBlock(next);
355     return returnOperation->createOpView();
356   }
357 
358   static void bind(py::module &m) {
359     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
360         .def("__iter__", &PyOperationIterator::dunderIter)
361         .def("__next__", &PyOperationIterator::dunderNext);
362   }
363 
364 private:
365   PyOperationRef parentOperation;
366   MlirOperation next;
367 };
368 
369 /// Operations are exposed by the C-API as a forward-only linked list. In
370 /// Python, we present them as a more full-featured list-like container but
371 /// optimize it for forward iteration. Iterable operations are always owned
372 /// by a block.
373 class PyOperationList {
374 public:
375   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
376       : parentOperation(std::move(parentOperation)), block(block) {}
377 
378   PyOperationIterator dunderIter() {
379     parentOperation->checkValid();
380     return PyOperationIterator(parentOperation,
381                                mlirBlockGetFirstOperation(block));
382   }
383 
384   intptr_t dunderLen() {
385     parentOperation->checkValid();
386     intptr_t count = 0;
387     MlirOperation childOp = mlirBlockGetFirstOperation(block);
388     while (!mlirOperationIsNull(childOp)) {
389       count += 1;
390       childOp = mlirOperationGetNextInBlock(childOp);
391     }
392     return count;
393   }
394 
395   py::object dunderGetItem(intptr_t index) {
396     parentOperation->checkValid();
397     if (index < 0) {
398       throw SetPyError(PyExc_IndexError,
399                        "attempt to access out of bounds operation");
400     }
401     MlirOperation childOp = mlirBlockGetFirstOperation(block);
402     while (!mlirOperationIsNull(childOp)) {
403       if (index == 0) {
404         return PyOperation::forOperation(parentOperation->getContext(), childOp)
405             ->createOpView();
406       }
407       childOp = mlirOperationGetNextInBlock(childOp);
408       index -= 1;
409     }
410     throw SetPyError(PyExc_IndexError,
411                      "attempt to access out of bounds operation");
412   }
413 
414   static void bind(py::module &m) {
415     py::class_<PyOperationList>(m, "OperationList", py::module_local())
416         .def("__getitem__", &PyOperationList::dunderGetItem)
417         .def("__iter__", &PyOperationList::dunderIter)
418         .def("__len__", &PyOperationList::dunderLen);
419   }
420 
421 private:
422   PyOperationRef parentOperation;
423   MlirBlock block;
424 };
425 
426 } // namespace
427 
428 //------------------------------------------------------------------------------
429 // PyMlirContext
430 //------------------------------------------------------------------------------
431 
432 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
433   py::gil_scoped_acquire acquire;
434   auto &liveContexts = getLiveContexts();
435   liveContexts[context.ptr] = this;
436 }
437 
438 PyMlirContext::~PyMlirContext() {
439   // Note that the only public way to construct an instance is via the
440   // forContext method, which always puts the associated handle into
441   // liveContexts.
442   py::gil_scoped_acquire acquire;
443   getLiveContexts().erase(context.ptr);
444   mlirContextDestroy(context);
445 }
446 
447 py::object PyMlirContext::getCapsule() {
448   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
449 }
450 
451 py::object PyMlirContext::createFromCapsule(py::object capsule) {
452   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
453   if (mlirContextIsNull(rawContext))
454     throw py::error_already_set();
455   return forContext(rawContext).releaseObject();
456 }
457 
458 PyMlirContext *PyMlirContext::createNewContextForInit() {
459   MlirContext context = mlirContextCreate();
460   mlirRegisterAllDialects(context);
461   return new PyMlirContext(context);
462 }
463 
464 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
465   py::gil_scoped_acquire acquire;
466   auto &liveContexts = getLiveContexts();
467   auto it = liveContexts.find(context.ptr);
468   if (it == liveContexts.end()) {
469     // Create.
470     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
471     py::object pyRef = py::cast(unownedContextWrapper);
472     assert(pyRef && "cast to py::object failed");
473     liveContexts[context.ptr] = unownedContextWrapper;
474     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
475   }
476   // Use existing.
477   py::object pyRef = py::cast(it->second);
478   return PyMlirContextRef(it->second, std::move(pyRef));
479 }
480 
481 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
482   static LiveContextMap liveContexts;
483   return liveContexts;
484 }
485 
486 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
487 
488 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
489 
490 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
491 
492 pybind11::object PyMlirContext::contextEnter() {
493   return PyThreadContextEntry::pushContext(*this);
494 }
495 
496 void PyMlirContext::contextExit(pybind11::object excType,
497                                 pybind11::object excVal,
498                                 pybind11::object excTb) {
499   PyThreadContextEntry::popContext(*this);
500 }
501 
502 PyMlirContext &DefaultingPyMlirContext::resolve() {
503   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
504   if (!context) {
505     throw SetPyError(
506         PyExc_RuntimeError,
507         "An MLIR function requires a Context but none was provided in the call "
508         "or from the surrounding environment. Either pass to the function with "
509         "a 'context=' argument or establish a default using 'with Context():'");
510   }
511   return *context;
512 }
513 
514 //------------------------------------------------------------------------------
515 // PyThreadContextEntry management
516 //------------------------------------------------------------------------------
517 
518 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
519   static thread_local std::vector<PyThreadContextEntry> stack;
520   return stack;
521 }
522 
523 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
524   auto &stack = getStack();
525   if (stack.empty())
526     return nullptr;
527   return &stack.back();
528 }
529 
530 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
531                                 py::object insertionPoint,
532                                 py::object location) {
533   auto &stack = getStack();
534   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
535                      std::move(location));
536   // If the new stack has more than one entry and the context of the new top
537   // entry matches the previous, copy the insertionPoint and location from the
538   // previous entry if missing from the new top entry.
539   if (stack.size() > 1) {
540     auto &prev = *(stack.rbegin() + 1);
541     auto &current = stack.back();
542     if (current.context.is(prev.context)) {
543       // Default non-context objects from the previous entry.
544       if (!current.insertionPoint)
545         current.insertionPoint = prev.insertionPoint;
546       if (!current.location)
547         current.location = prev.location;
548     }
549   }
550 }
551 
552 PyMlirContext *PyThreadContextEntry::getContext() {
553   if (!context)
554     return nullptr;
555   return py::cast<PyMlirContext *>(context);
556 }
557 
558 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
559   if (!insertionPoint)
560     return nullptr;
561   return py::cast<PyInsertionPoint *>(insertionPoint);
562 }
563 
564 PyLocation *PyThreadContextEntry::getLocation() {
565   if (!location)
566     return nullptr;
567   return py::cast<PyLocation *>(location);
568 }
569 
570 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
571   auto *tos = getTopOfStack();
572   return tos ? tos->getContext() : nullptr;
573 }
574 
575 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
576   auto *tos = getTopOfStack();
577   return tos ? tos->getInsertionPoint() : nullptr;
578 }
579 
580 PyLocation *PyThreadContextEntry::getDefaultLocation() {
581   auto *tos = getTopOfStack();
582   return tos ? tos->getLocation() : nullptr;
583 }
584 
585 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
586   py::object contextObj = py::cast(context);
587   push(FrameKind::Context, /*context=*/contextObj,
588        /*insertionPoint=*/py::object(),
589        /*location=*/py::object());
590   return contextObj;
591 }
592 
593 void PyThreadContextEntry::popContext(PyMlirContext &context) {
594   auto &stack = getStack();
595   if (stack.empty())
596     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
597   auto &tos = stack.back();
598   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
599     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
600   stack.pop_back();
601 }
602 
603 py::object
604 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
605   py::object contextObj =
606       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
607   py::object insertionPointObj = py::cast(insertionPoint);
608   push(FrameKind::InsertionPoint,
609        /*context=*/contextObj,
610        /*insertionPoint=*/insertionPointObj,
611        /*location=*/py::object());
612   return insertionPointObj;
613 }
614 
615 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
616   auto &stack = getStack();
617   if (stack.empty())
618     throw SetPyError(PyExc_RuntimeError,
619                      "Unbalanced InsertionPoint enter/exit");
620   auto &tos = stack.back();
621   if (tos.frameKind != FrameKind::InsertionPoint &&
622       tos.getInsertionPoint() != &insertionPoint)
623     throw SetPyError(PyExc_RuntimeError,
624                      "Unbalanced InsertionPoint enter/exit");
625   stack.pop_back();
626 }
627 
628 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
629   py::object contextObj = location.getContext().getObject();
630   py::object locationObj = py::cast(location);
631   push(FrameKind::Location, /*context=*/contextObj,
632        /*insertionPoint=*/py::object(),
633        /*location=*/locationObj);
634   return locationObj;
635 }
636 
637 void PyThreadContextEntry::popLocation(PyLocation &location) {
638   auto &stack = getStack();
639   if (stack.empty())
640     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
641   auto &tos = stack.back();
642   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
643     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
644   stack.pop_back();
645 }
646 
647 //------------------------------------------------------------------------------
648 // PyDialect, PyDialectDescriptor, PyDialects
649 //------------------------------------------------------------------------------
650 
651 MlirDialect PyDialects::getDialectForKey(const std::string &key,
652                                          bool attrError) {
653   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
654                                                     {key.data(), key.size()});
655   if (mlirDialectIsNull(dialect)) {
656     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
657                      Twine("Dialect '") + key + "' not found");
658   }
659   return dialect;
660 }
661 
662 //------------------------------------------------------------------------------
663 // PyLocation
664 //------------------------------------------------------------------------------
665 
666 py::object PyLocation::getCapsule() {
667   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
668 }
669 
670 PyLocation PyLocation::createFromCapsule(py::object capsule) {
671   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
672   if (mlirLocationIsNull(rawLoc))
673     throw py::error_already_set();
674   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
675                     rawLoc);
676 }
677 
678 py::object PyLocation::contextEnter() {
679   return PyThreadContextEntry::pushLocation(*this);
680 }
681 
682 void PyLocation::contextExit(py::object excType, py::object excVal,
683                              py::object excTb) {
684   PyThreadContextEntry::popLocation(*this);
685 }
686 
687 PyLocation &DefaultingPyLocation::resolve() {
688   auto *location = PyThreadContextEntry::getDefaultLocation();
689   if (!location) {
690     throw SetPyError(
691         PyExc_RuntimeError,
692         "An MLIR function requires a Location but none was provided in the "
693         "call or from the surrounding environment. Either pass to the function "
694         "with a 'loc=' argument or establish a default using 'with loc:'");
695   }
696   return *location;
697 }
698 
699 //------------------------------------------------------------------------------
700 // PyModule
701 //------------------------------------------------------------------------------
702 
/// Stores the raw module handle and hands the context reference to the
/// BaseContextObject base. Registration in the context's live-module map is
/// not done here; forModule performs it after construction.
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
    : BaseContextObject(std::move(contextRef)), module(module) {}
705 
706 PyModule::~PyModule() {
707   py::gil_scoped_acquire acquire;
708   auto &liveModules = getContext()->liveModules;
709   assert(liveModules.count(module.ptr) == 1 &&
710          "destroying module not in live map");
711   liveModules.erase(module.ptr);
712   mlirModuleDestroy(module);
713 }
714 
715 PyModuleRef PyModule::forModule(MlirModule module) {
716   MlirContext context = mlirModuleGetContext(module);
717   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
718 
719   py::gil_scoped_acquire acquire;
720   auto &liveModules = contextRef->liveModules;
721   auto it = liveModules.find(module.ptr);
722   if (it == liveModules.end()) {
723     // Create.
724     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
725     // Note that the default return value policy on cast is automatic_reference,
726     // which does not take ownership (delete will not be called).
727     // Just be explicit.
728     py::object pyRef =
729         py::cast(unownedModule, py::return_value_policy::take_ownership);
730     unownedModule->handle = pyRef;
731     liveModules[module.ptr] =
732         std::make_pair(unownedModule->handle, unownedModule);
733     return PyModuleRef(unownedModule, std::move(pyRef));
734   }
735   // Use existing.
736   PyModule *existing = it->second.second;
737   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
738   return PyModuleRef(existing, std::move(pyRef));
739 }
740 
741 py::object PyModule::createFromCapsule(py::object capsule) {
742   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
743   if (mlirModuleIsNull(rawModule))
744     throw py::error_already_set();
745   return forModule(rawModule).releaseObject();
746 }
747 
748 py::object PyModule::getCapsule() {
749   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
750 }
751 
752 //------------------------------------------------------------------------------
753 // PyOperation
754 //------------------------------------------------------------------------------
755 
/// Stores the raw operation handle and hands the context reference to the
/// BaseContextObject base. Registration in the context's live-operation map is
/// not done here; createInstance performs it after construction.
PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
    : BaseContextObject(std::move(contextRef)), operation(operation) {}
758 
759 PyOperation::~PyOperation() {
760   // If the operation has already been invalidated there is nothing to do.
761   if (!valid)
762     return;
763   auto &liveOperations = getContext()->liveOperations;
764   assert(liveOperations.count(operation.ptr) == 1 &&
765          "destroying operation not in live map");
766   liveOperations.erase(operation.ptr);
767   if (!isAttached()) {
768     mlirOperationDestroy(operation);
769   }
770 }
771 
772 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
773                                            MlirOperation operation,
774                                            py::object parentKeepAlive) {
775   auto &liveOperations = contextRef->liveOperations;
776   // Create.
777   PyOperation *unownedOperation =
778       new PyOperation(std::move(contextRef), operation);
779   // Note that the default return value policy on cast is automatic_reference,
780   // which does not take ownership (delete will not be called).
781   // Just be explicit.
782   py::object pyRef =
783       py::cast(unownedOperation, py::return_value_policy::take_ownership);
784   unownedOperation->handle = pyRef;
785   if (parentKeepAlive) {
786     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
787   }
788   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
789   return PyOperationRef(unownedOperation, std::move(pyRef));
790 }
791 
792 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
793                                          MlirOperation operation,
794                                          py::object parentKeepAlive) {
795   auto &liveOperations = contextRef->liveOperations;
796   auto it = liveOperations.find(operation.ptr);
797   if (it == liveOperations.end()) {
798     // Create.
799     return createInstance(std::move(contextRef), operation,
800                           std::move(parentKeepAlive));
801   }
802   // Use existing.
803   PyOperation *existing = it->second.second;
804   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
805   return PyOperationRef(existing, std::move(pyRef));
806 }
807 
808 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
809                                            MlirOperation operation,
810                                            py::object parentKeepAlive) {
811   auto &liveOperations = contextRef->liveOperations;
812   assert(liveOperations.count(operation.ptr) == 0 &&
813          "cannot create detached operation that already exists");
814   (void)liveOperations;
815 
816   PyOperationRef created = createInstance(std::move(contextRef), operation,
817                                           std::move(parentKeepAlive));
818   created->attached = false;
819   return created;
820 }
821 
822 void PyOperation::checkValid() const {
823   if (!valid) {
824     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
825   }
826 }
827 
828 void PyOperationBase::print(py::object fileObject, bool binary,
829                             llvm::Optional<int64_t> largeElementsLimit,
830                             bool enableDebugInfo, bool prettyDebugInfo,
831                             bool printGenericOpForm, bool useLocalScope) {
832   PyOperation &operation = getOperation();
833   operation.checkValid();
834   if (fileObject.is_none())
835     fileObject = py::module::import("sys").attr("stdout");
836 
837   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
838     fileObject.attr("write")("// Verification failed, printing generic form\n");
839     printGenericOpForm = true;
840   }
841 
842   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
843   if (largeElementsLimit)
844     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
845   if (enableDebugInfo)
846     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
847   if (printGenericOpForm)
848     mlirOpPrintingFlagsPrintGenericOpForm(flags);
849 
850   PyFileAccumulator accum(fileObject, binary);
851   py::gil_scoped_release();
852   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
853                               accum.getUserData());
854   mlirOpPrintingFlagsDestroy(flags);
855 }
856 
857 py::object PyOperationBase::getAsm(bool binary,
858                                    llvm::Optional<int64_t> largeElementsLimit,
859                                    bool enableDebugInfo, bool prettyDebugInfo,
860                                    bool printGenericOpForm,
861                                    bool useLocalScope) {
862   py::object fileObject;
863   if (binary) {
864     fileObject = py::module::import("io").attr("BytesIO")();
865   } else {
866     fileObject = py::module::import("io").attr("StringIO")();
867   }
868   print(fileObject, /*binary=*/binary,
869         /*largeElementsLimit=*/largeElementsLimit,
870         /*enableDebugInfo=*/enableDebugInfo,
871         /*prettyDebugInfo=*/prettyDebugInfo,
872         /*printGenericOpForm=*/printGenericOpForm,
873         /*useLocalScope=*/useLocalScope);
874 
875   return fileObject.attr("getvalue")();
876 }
877 
878 void PyOperationBase::moveAfter(PyOperationBase &other) {
879   PyOperation &operation = getOperation();
880   PyOperation &otherOp = other.getOperation();
881   operation.checkValid();
882   otherOp.checkValid();
883   mlirOperationMoveAfter(operation, otherOp);
884   operation.parentKeepAlive = otherOp.parentKeepAlive;
885 }
886 
887 void PyOperationBase::moveBefore(PyOperationBase &other) {
888   PyOperation &operation = getOperation();
889   PyOperation &otherOp = other.getOperation();
890   operation.checkValid();
891   otherOp.checkValid();
892   mlirOperationMoveBefore(operation, otherOp);
893   operation.parentKeepAlive = otherOp.parentKeepAlive;
894 }
895 
896 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
897   checkValid();
898   if (!isAttached())
899     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
900   MlirOperation operation = mlirOperationGetParentOperation(get());
901   if (mlirOperationIsNull(operation))
902     return {};
903   return PyOperation::forOperation(getContext(), operation);
904 }
905 
906 PyBlock PyOperation::getBlock() {
907   checkValid();
908   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
909   MlirBlock block = mlirOperationGetBlock(get());
910   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
911   assert(parentOperation && "Operation has no parent");
912   return PyBlock{std::move(*parentOperation), block};
913 }
914 
915 py::object PyOperation::getCapsule() {
916   checkValid();
917   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
918 }
919 
920 py::object PyOperation::createFromCapsule(py::object capsule) {
921   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
922   if (mlirOperationIsNull(rawOperation))
923     throw py::error_already_set();
924   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
925   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
926       .releaseObject();
927 }
928 
929 py::object PyOperation::create(
930     std::string name, llvm::Optional<std::vector<PyType *>> results,
931     llvm::Optional<std::vector<PyValue *>> operands,
932     llvm::Optional<py::dict> attributes,
933     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
934     DefaultingPyLocation location, py::object maybeIp) {
935   llvm::SmallVector<MlirValue, 4> mlirOperands;
936   llvm::SmallVector<MlirType, 4> mlirResults;
937   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
938   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
939 
940   // General parameter validation.
941   if (regions < 0)
942     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
943 
944   // Unpack/validate operands.
945   if (operands) {
946     mlirOperands.reserve(operands->size());
947     for (PyValue *operand : *operands) {
948       if (!operand)
949         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
950       mlirOperands.push_back(operand->get());
951     }
952   }
953 
954   // Unpack/validate results.
955   if (results) {
956     mlirResults.reserve(results->size());
957     for (PyType *result : *results) {
958       // TODO: Verify result type originate from the same context.
959       if (!result)
960         throw SetPyError(PyExc_ValueError, "result type cannot be None");
961       mlirResults.push_back(*result);
962     }
963   }
964   // Unpack/validate attributes.
965   if (attributes) {
966     mlirAttributes.reserve(attributes->size());
967     for (auto &it : *attributes) {
968       std::string key;
969       try {
970         key = it.first.cast<std::string>();
971       } catch (py::cast_error &err) {
972         std::string msg = "Invalid attribute key (not a string) when "
973                           "attempting to create the operation \"" +
974                           name + "\" (" + err.what() + ")";
975         throw py::cast_error(msg);
976       }
977       try {
978         auto &attribute = it.second.cast<PyAttribute &>();
979         // TODO: Verify attribute originates from the same context.
980         mlirAttributes.emplace_back(std::move(key), attribute);
981       } catch (py::reference_cast_error &) {
982         // This exception seems thrown when the value is "None".
983         std::string msg =
984             "Found an invalid (`None`?) attribute value for the key \"" + key +
985             "\" when attempting to create the operation \"" + name + "\"";
986         throw py::cast_error(msg);
987       } catch (py::cast_error &err) {
988         std::string msg = "Invalid attribute value for the key \"" + key +
989                           "\" when attempting to create the operation \"" +
990                           name + "\" (" + err.what() + ")";
991         throw py::cast_error(msg);
992       }
993     }
994   }
995   // Unpack/validate successors.
996   if (successors) {
997     mlirSuccessors.reserve(successors->size());
998     for (auto *successor : *successors) {
999       // TODO: Verify successor originate from the same context.
1000       if (!successor)
1001         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
1002       mlirSuccessors.push_back(successor->get());
1003     }
1004   }
1005 
1006   // Apply unpacked/validated to the operation state. Beyond this
1007   // point, exceptions cannot be thrown or else the state will leak.
1008   MlirOperationState state =
1009       mlirOperationStateGet(toMlirStringRef(name), location);
1010   if (!mlirOperands.empty())
1011     mlirOperationStateAddOperands(&state, mlirOperands.size(),
1012                                   mlirOperands.data());
1013   if (!mlirResults.empty())
1014     mlirOperationStateAddResults(&state, mlirResults.size(),
1015                                  mlirResults.data());
1016   if (!mlirAttributes.empty()) {
1017     // Note that the attribute names directly reference bytes in
1018     // mlirAttributes, so that vector must not be changed from here
1019     // on.
1020     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
1021     mlirNamedAttributes.reserve(mlirAttributes.size());
1022     for (auto &it : mlirAttributes)
1023       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
1024           mlirIdentifierGet(mlirAttributeGetContext(it.second),
1025                             toMlirStringRef(it.first)),
1026           it.second));
1027     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
1028                                     mlirNamedAttributes.data());
1029   }
1030   if (!mlirSuccessors.empty())
1031     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
1032                                     mlirSuccessors.data());
1033   if (regions) {
1034     llvm::SmallVector<MlirRegion, 4> mlirRegions;
1035     mlirRegions.resize(regions);
1036     for (int i = 0; i < regions; ++i)
1037       mlirRegions[i] = mlirRegionCreate();
1038     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
1039                                       mlirRegions.data());
1040   }
1041 
1042   // Construct the operation.
1043   MlirOperation operation = mlirOperationCreate(&state);
1044   PyOperationRef created =
1045       PyOperation::createDetached(location->getContext(), operation);
1046 
1047   // InsertPoint active?
1048   if (!maybeIp.is(py::cast(false))) {
1049     PyInsertionPoint *ip;
1050     if (maybeIp.is_none()) {
1051       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1052     } else {
1053       ip = py::cast<PyInsertionPoint *>(maybeIp);
1054     }
1055     if (ip)
1056       ip->insert(*created.get());
1057   }
1058 
1059   return created->createOpView();
1060 }
1061 
1062 py::object PyOperation::createOpView() {
1063   checkValid();
1064   MlirIdentifier ident = mlirOperationGetName(get());
1065   MlirStringRef identStr = mlirIdentifierStr(ident);
1066   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1067       StringRef(identStr.data, identStr.length));
1068   if (opViewClass)
1069     return (*opViewClass)(getRef().getObject());
1070   return py::cast(PyOpView(getRef().getObject()));
1071 }
1072 
1073 void PyOperation::erase() {
1074   checkValid();
1075   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1076   // Python reference to a child operation is live. All children should also
1077   // have their `valid` bit set to false.
1078   auto &liveOperations = getContext()->liveOperations;
1079   if (liveOperations.count(operation.ptr))
1080     liveOperations.erase(operation.ptr);
1081   mlirOperationDestroy(operation);
1082   valid = false;
1083 }
1084 
1085 //------------------------------------------------------------------------------
1086 // PyOpView
1087 //------------------------------------------------------------------------------
1088 
1089 py::object
1090 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1091                        py::list operandList,
1092                        llvm::Optional<py::dict> attributes,
1093                        llvm::Optional<std::vector<PyBlock *>> successors,
1094                        llvm::Optional<int> regions,
1095                        DefaultingPyLocation location, py::object maybeIp) {
1096   PyMlirContextRef context = location->getContext();
1097   // Class level operation construction metadata.
1098   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1099   // Operand and result segment specs are either none, which does no
1100   // variadic unpacking, or a list of ints with segment sizes, where each
1101   // element is either a positive number (typically 1 for a scalar) or -1 to
1102   // indicate that it is derived from the length of the same-indexed operand
1103   // or result (implying that it is a list at that position).
1104   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1105   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1106 
1107   std::vector<uint32_t> operandSegmentLengths;
1108   std::vector<uint32_t> resultSegmentLengths;
1109 
1110   // Validate/determine region count.
1111   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1112   int opMinRegionCount = std::get<0>(opRegionSpec);
1113   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1114   if (!regions) {
1115     regions = opMinRegionCount;
1116   }
1117   if (*regions < opMinRegionCount) {
1118     throw py::value_error(
1119         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1120          llvm::Twine(opMinRegionCount) +
1121          " regions but was built with regions=" + llvm::Twine(*regions))
1122             .str());
1123   }
1124   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1125     throw py::value_error(
1126         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1127          llvm::Twine(opMinRegionCount) +
1128          " regions but was built with regions=" + llvm::Twine(*regions))
1129             .str());
1130   }
1131 
1132   // Unpack results.
1133   std::vector<PyType *> resultTypes;
1134   resultTypes.reserve(resultTypeList.size());
1135   if (resultSegmentSpecObj.is_none()) {
1136     // Non-variadic result unpacking.
1137     for (auto it : llvm::enumerate(resultTypeList)) {
1138       try {
1139         resultTypes.push_back(py::cast<PyType *>(it.value()));
1140         if (!resultTypes.back())
1141           throw py::cast_error();
1142       } catch (py::cast_error &err) {
1143         throw py::value_error((llvm::Twine("Result ") +
1144                                llvm::Twine(it.index()) + " of operation \"" +
1145                                name + "\" must be a Type (" + err.what() + ")")
1146                                   .str());
1147       }
1148     }
1149   } else {
1150     // Sized result unpacking.
1151     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1152     if (resultSegmentSpec.size() != resultTypeList.size()) {
1153       throw py::value_error((llvm::Twine("Operation \"") + name +
1154                              "\" requires " +
1155                              llvm::Twine(resultSegmentSpec.size()) +
1156                              " result segments but was provided " +
1157                              llvm::Twine(resultTypeList.size()))
1158                                 .str());
1159     }
1160     resultSegmentLengths.reserve(resultTypeList.size());
1161     for (auto it :
1162          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1163       int segmentSpec = std::get<1>(it.value());
1164       if (segmentSpec == 1 || segmentSpec == 0) {
1165         // Unpack unary element.
1166         try {
1167           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1168           if (resultType) {
1169             resultTypes.push_back(resultType);
1170             resultSegmentLengths.push_back(1);
1171           } else if (segmentSpec == 0) {
1172             // Allowed to be optional.
1173             resultSegmentLengths.push_back(0);
1174           } else {
1175             throw py::cast_error("was None and result is not optional");
1176           }
1177         } catch (py::cast_error &err) {
1178           throw py::value_error((llvm::Twine("Result ") +
1179                                  llvm::Twine(it.index()) + " of operation \"" +
1180                                  name + "\" must be a Type (" + err.what() +
1181                                  ")")
1182                                     .str());
1183         }
1184       } else if (segmentSpec == -1) {
1185         // Unpack sequence by appending.
1186         try {
1187           if (std::get<0>(it.value()).is_none()) {
1188             // Treat it as an empty list.
1189             resultSegmentLengths.push_back(0);
1190           } else {
1191             // Unpack the list.
1192             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1193             for (py::object segmentItem : segment) {
1194               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1195               if (!resultTypes.back()) {
1196                 throw py::cast_error("contained a None item");
1197               }
1198             }
1199             resultSegmentLengths.push_back(segment.size());
1200           }
1201         } catch (std::exception &err) {
1202           // NOTE: Sloppy to be using a catch-all here, but there are at least
1203           // three different unrelated exceptions that can be thrown in the
1204           // above "casts". Just keep the scope above small and catch them all.
1205           throw py::value_error((llvm::Twine("Result ") +
1206                                  llvm::Twine(it.index()) + " of operation \"" +
1207                                  name + "\" must be a Sequence of Types (" +
1208                                  err.what() + ")")
1209                                     .str());
1210         }
1211       } else {
1212         throw py::value_error("Unexpected segment spec");
1213       }
1214     }
1215   }
1216 
1217   // Unpack operands.
1218   std::vector<PyValue *> operands;
1219   operands.reserve(operands.size());
1220   if (operandSegmentSpecObj.is_none()) {
1221     // Non-sized operand unpacking.
1222     for (auto it : llvm::enumerate(operandList)) {
1223       try {
1224         operands.push_back(py::cast<PyValue *>(it.value()));
1225         if (!operands.back())
1226           throw py::cast_error();
1227       } catch (py::cast_error &err) {
1228         throw py::value_error((llvm::Twine("Operand ") +
1229                                llvm::Twine(it.index()) + " of operation \"" +
1230                                name + "\" must be a Value (" + err.what() + ")")
1231                                   .str());
1232       }
1233     }
1234   } else {
1235     // Sized operand unpacking.
1236     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1237     if (operandSegmentSpec.size() != operandList.size()) {
1238       throw py::value_error((llvm::Twine("Operation \"") + name +
1239                              "\" requires " +
1240                              llvm::Twine(operandSegmentSpec.size()) +
1241                              "operand segments but was provided " +
1242                              llvm::Twine(operandList.size()))
1243                                 .str());
1244     }
1245     operandSegmentLengths.reserve(operandList.size());
1246     for (auto it :
1247          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1248       int segmentSpec = std::get<1>(it.value());
1249       if (segmentSpec == 1 || segmentSpec == 0) {
1250         // Unpack unary element.
1251         try {
1252           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1253           if (operandValue) {
1254             operands.push_back(operandValue);
1255             operandSegmentLengths.push_back(1);
1256           } else if (segmentSpec == 0) {
1257             // Allowed to be optional.
1258             operandSegmentLengths.push_back(0);
1259           } else {
1260             throw py::cast_error("was None and operand is not optional");
1261           }
1262         } catch (py::cast_error &err) {
1263           throw py::value_error((llvm::Twine("Operand ") +
1264                                  llvm::Twine(it.index()) + " of operation \"" +
1265                                  name + "\" must be a Value (" + err.what() +
1266                                  ")")
1267                                     .str());
1268         }
1269       } else if (segmentSpec == -1) {
1270         // Unpack sequence by appending.
1271         try {
1272           if (std::get<0>(it.value()).is_none()) {
1273             // Treat it as an empty list.
1274             operandSegmentLengths.push_back(0);
1275           } else {
1276             // Unpack the list.
1277             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1278             for (py::object segmentItem : segment) {
1279               operands.push_back(py::cast<PyValue *>(segmentItem));
1280               if (!operands.back()) {
1281                 throw py::cast_error("contained a None item");
1282               }
1283             }
1284             operandSegmentLengths.push_back(segment.size());
1285           }
1286         } catch (std::exception &err) {
1287           // NOTE: Sloppy to be using a catch-all here, but there are at least
1288           // three different unrelated exceptions that can be thrown in the
1289           // above "casts". Just keep the scope above small and catch them all.
1290           throw py::value_error((llvm::Twine("Operand ") +
1291                                  llvm::Twine(it.index()) + " of operation \"" +
1292                                  name + "\" must be a Sequence of Values (" +
1293                                  err.what() + ")")
1294                                     .str());
1295         }
1296       } else {
1297         throw py::value_error("Unexpected segment spec");
1298       }
1299     }
1300   }
1301 
1302   // Merge operand/result segment lengths into attributes if needed.
1303   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1304     // Dup.
1305     if (attributes) {
1306       attributes = py::dict(*attributes);
1307     } else {
1308       attributes = py::dict();
1309     }
1310     if (attributes->contains("result_segment_sizes") ||
1311         attributes->contains("operand_segment_sizes")) {
1312       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1313                             "'operand_segment_sizes' attribute is unsupported. "
1314                             "Use Operation.create for such low-level access.");
1315     }
1316 
1317     // Add result_segment_sizes attribute.
1318     if (!resultSegmentLengths.empty()) {
1319       int64_t size = resultSegmentLengths.size();
1320       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1321           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1322           resultSegmentLengths.size(), resultSegmentLengths.data());
1323       (*attributes)["result_segment_sizes"] =
1324           PyAttribute(context, segmentLengthAttr);
1325     }
1326 
1327     // Add operand_segment_sizes attribute.
1328     if (!operandSegmentLengths.empty()) {
1329       int64_t size = operandSegmentLengths.size();
1330       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1331           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1332           operandSegmentLengths.size(), operandSegmentLengths.data());
1333       (*attributes)["operand_segment_sizes"] =
1334           PyAttribute(context, segmentLengthAttr);
1335     }
1336   }
1337 
1338   // Delegate to create.
1339   return PyOperation::create(std::move(name),
1340                              /*results=*/std::move(resultTypes),
1341                              /*operands=*/std::move(operands),
1342                              /*attributes=*/std::move(attributes),
1343                              /*successors=*/std::move(successors),
1344                              /*regions=*/*regions, location, maybeIp);
1345 }
1346 
1347 PyOpView::PyOpView(py::object operationObject)
1348     // Casting through the PyOperationBase base-class and then back to the
1349     // Operation lets us accept any PyOperationBase subclass.
1350     : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1351       operationObject(operation.getRef().getObject()) {}
1352 
1353 py::object PyOpView::createRawSubclass(py::object userClass) {
1354   // This is... a little gross. The typical pattern is to have a pure python
1355   // class that extends OpView like:
1356   //   class AddFOp(_cext.ir.OpView):
1357   //     def __init__(self, loc, lhs, rhs):
1358   //       operation = loc.context.create_operation(
1359   //           "addf", lhs, rhs, results=[lhs.type])
1360   //       super().__init__(operation)
1361   //
1362   // I.e. The goal of the user facing type is to provide a nice constructor
1363   // that has complete freedom for the op under construction. This is at odds
1364   // with our other desire to sometimes create this object by just passing an
1365   // operation (to initialize the base class). We could do *arg and **kwargs
1366   // munging to try to make it work, but instead, we synthesize a new class
1367   // on the fly which extends this user class (AddFOp in this example) and
1368   // *give it* the base class's __init__ method, thus bypassing the
1369   // intermediate subclass's __init__ method entirely. While slightly,
1370   // underhanded, this is safe/legal because the type hierarchy has not changed
1371   // (we just added a new leaf) and we aren't mucking around with __new__.
1372   // Typically, this new class will be stored on the original as "_Raw" and will
1373   // be used for casts and other things that need a variant of the class that
1374   // is initialized purely from an operation.
1375   py::object parentMetaclass =
1376       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1377   py::dict attributes;
1378   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1379   // now.
1380   //   auto opViewType = py::type::of<PyOpView>();
1381   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1382   attributes["__init__"] = opViewType.attr("__init__");
1383   py::str origName = userClass.attr("__name__");
1384   py::str newName = py::str("_") + origName;
1385   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1386 }
1387 
1388 //------------------------------------------------------------------------------
1389 // PyInsertionPoint.
1390 //------------------------------------------------------------------------------
1391 
1392 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1393 
1394 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1395     : refOperation(beforeOperationBase.getOperation().getRef()),
1396       block((*refOperation)->getBlock()) {}
1397 
1398 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1399   PyOperation &operation = operationBase.getOperation();
1400   if (operation.isAttached())
1401     throw SetPyError(PyExc_ValueError,
1402                      "Attempt to insert operation that is already attached");
1403   block.getParentOperation()->checkValid();
1404   MlirOperation beforeOp = {nullptr};
1405   if (refOperation) {
1406     // Insert before operation.
1407     (*refOperation)->checkValid();
1408     beforeOp = (*refOperation)->get();
1409   } else {
1410     // Insert at end (before null) is only valid if the block does not
1411     // already end in a known terminator (violating this will cause assertion
1412     // failures later).
1413     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1414       throw py::index_error("Cannot insert operation at the end of a block "
1415                             "that already has a terminator. Did you mean to "
1416                             "use 'InsertionPoint.at_block_terminator(block)' "
1417                             "versus 'InsertionPoint(block)'?");
1418     }
1419   }
1420   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1421   operation.setAttached();
1422 }
1423 
1424 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1425   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1426   if (mlirOperationIsNull(firstOp)) {
1427     // Just insert at end.
1428     return PyInsertionPoint(block);
1429   }
1430 
1431   // Insert before first op.
1432   PyOperationRef firstOpRef = PyOperation::forOperation(
1433       block.getParentOperation()->getContext(), firstOp);
1434   return PyInsertionPoint{block, std::move(firstOpRef)};
1435 }
1436 
1437 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1438   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1439   if (mlirOperationIsNull(terminator))
1440     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1441   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1442       block.getParentOperation()->getContext(), terminator);
1443   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1444 }
1445 
1446 py::object PyInsertionPoint::contextEnter() {
1447   return PyThreadContextEntry::pushInsertionPoint(*this);
1448 }
1449 
1450 void PyInsertionPoint::contextExit(pybind11::object excType,
1451                                    pybind11::object excVal,
1452                                    pybind11::object excTb) {
1453   PyThreadContextEntry::popInsertionPoint(*this);
1454 }
1455 
1456 //------------------------------------------------------------------------------
1457 // PyAttribute.
1458 //------------------------------------------------------------------------------
1459 
1460 bool PyAttribute::operator==(const PyAttribute &other) {
1461   return mlirAttributeEqual(attr, other.attr);
1462 }
1463 
1464 py::object PyAttribute::getCapsule() {
1465   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1466 }
1467 
1468 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1469   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1470   if (mlirAttributeIsNull(rawAttr))
1471     throw py::error_already_set();
1472   return PyAttribute(
1473       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1474 }
1475 
1476 //------------------------------------------------------------------------------
1477 // PyNamedAttribute.
1478 //------------------------------------------------------------------------------
1479 
1480 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1481     : ownedName(new std::string(std::move(ownedName))) {
1482   namedAttr = mlirNamedAttributeGet(
1483       mlirIdentifierGet(mlirAttributeGetContext(attr),
1484                         toMlirStringRef(*this->ownedName)),
1485       attr);
1486 }
1487 
1488 //------------------------------------------------------------------------------
1489 // PyType.
1490 //------------------------------------------------------------------------------
1491 
1492 bool PyType::operator==(const PyType &other) {
1493   return mlirTypeEqual(type, other.type);
1494 }
1495 
1496 py::object PyType::getCapsule() {
1497   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1498 }
1499 
1500 PyType PyType::createFromCapsule(py::object capsule) {
1501   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1502   if (mlirTypeIsNull(rawType))
1503     throw py::error_already_set();
1504   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1505                 rawType);
1506 }
1507 
1508 //------------------------------------------------------------------------------
// PyValue and subclasses.
1510 //------------------------------------------------------------------------------
1511 
1512 pybind11::object PyValue::getCapsule() {
1513   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1514 }
1515 
1516 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1517   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1518   if (mlirValueIsNull(value))
1519     throw py::error_already_set();
1520   MlirOperation owner;
1521   if (mlirValueIsAOpResult(value))
1522     owner = mlirOpResultGetOwner(value);
1523   if (mlirValueIsABlockArgument(value))
1524     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1525   if (mlirOperationIsNull(owner))
1526     throw py::error_already_set();
1527   MlirContext ctx = mlirOperationGetContext(owner);
1528   PyOperationRef ownerRef =
1529       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1530   return PyValue(ownerRef, value);
1531 }
1532 
1533 //------------------------------------------------------------------------------
1534 // PySymbolTable.
1535 //------------------------------------------------------------------------------
1536 
1537 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1538     : operation(operation.getOperation().getRef()) {
1539   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1540   if (mlirSymbolTableIsNull(symbolTable)) {
1541     throw py::cast_error("Operation is not a Symbol Table.");
1542   }
1543 }
1544 
1545 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1546   operation->checkValid();
1547   MlirOperation symbol = mlirSymbolTableLookup(
1548       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1549   if (mlirOperationIsNull(symbol))
1550     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1551 
1552   return PyOperation::forOperation(operation->getContext(), symbol,
1553                                    operation.getObject())
1554       ->createOpView();
1555 }
1556 
1557 void PySymbolTable::erase(PyOperationBase &symbol) {
1558   operation->checkValid();
1559   symbol.getOperation().checkValid();
1560   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1561   // The operation is also erased, so we must invalidate it. There may be Python
1562   // references to this operation so we don't want to delete it from the list of
1563   // live operations here.
1564   symbol.getOperation().valid = false;
1565 }
1566 
1567 void PySymbolTable::dunderDel(const std::string &name) {
1568   py::object operation = dunderGetItem(name);
1569   erase(py::cast<PyOperationBase &>(operation));
1570 }
1571 
1572 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1573   operation->checkValid();
1574   symbol.getOperation().checkValid();
1575   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1576       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1577   if (mlirAttributeIsNull(symbolAttr))
1578     throw py::value_error("Expected operation to have a symbol name.");
1579   return PyAttribute(
1580       symbol.getOperation().getContext(),
1581       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1582 }
1583 
1584 namespace {
1585 /// CRTP base class for Python MLIR values that subclass Value and should be
1586 /// castable from it. The value hierarchy is one level deep and is not supposed
1587 /// to accommodate other levels unless core MLIR changes.
1588 template <typename DerivedTy>
1589 class PyConcreteValue : public PyValue {
1590 public:
1591   // Derived classes must define statics for:
1592   //   IsAFunctionTy isaFunction
1593   //   const char *pyClassName
1594   // and redefine bindDerived.
1595   using ClassTy = py::class_<DerivedTy, PyValue>;
1596   using IsAFunctionTy = bool (*)(MlirValue);
1597 
1598   PyConcreteValue() = default;
1599   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1600       : PyValue(operationRef, value) {}
1601   PyConcreteValue(PyValue &orig)
1602       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1603 
1604   /// Attempts to cast the original value to the derived type and throws on
1605   /// type mismatches.
1606   static MlirValue castFrom(PyValue &orig) {
1607     if (!DerivedTy::isaFunction(orig.get())) {
1608       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1609       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1610                                              DerivedTy::pyClassName +
1611                                              " (from " + origRepr + ")");
1612     }
1613     return orig.get();
1614   }
1615 
1616   /// Binds the Python module objects to functions of this class.
1617   static void bind(py::module &m) {
1618     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1619     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1620     cls.def_static("isinstance", [](PyValue &otherValue) -> bool {
1621       return DerivedTy::isaFunction(otherValue);
1622     });
1623     DerivedTy::bindDerived(cls);
1624   }
1625 
1626   /// Implemented by derived classes to add methods to the Python subclass.
1627   static void bindDerived(ClassTy &m) {}
1628 };
1629 
1630 /// Python wrapper for MlirBlockArgument.
1631 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1632 public:
1633   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1634   static constexpr const char *pyClassName = "BlockArgument";
1635   using PyConcreteValue::PyConcreteValue;
1636 
1637   static void bindDerived(ClassTy &c) {
1638     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1639       return PyBlock(self.getParentOperation(),
1640                      mlirBlockArgumentGetOwner(self.get()));
1641     });
1642     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1643       return mlirBlockArgumentGetArgNumber(self.get());
1644     });
1645     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1646       return mlirBlockArgumentSetType(self.get(), type);
1647     });
1648   }
1649 };
1650 
1651 /// Python wrapper for MlirOpResult.
1652 class PyOpResult : public PyConcreteValue<PyOpResult> {
1653 public:
1654   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1655   static constexpr const char *pyClassName = "OpResult";
1656   using PyConcreteValue::PyConcreteValue;
1657 
1658   static void bindDerived(ClassTy &c) {
1659     c.def_property_readonly("owner", [](PyOpResult &self) {
1660       assert(
1661           mlirOperationEqual(self.getParentOperation()->get(),
1662                              mlirOpResultGetOwner(self.get())) &&
1663           "expected the owner of the value in Python to match that in the IR");
1664       return self.getParentOperation().getObject();
1665     });
1666     c.def_property_readonly("result_number", [](PyOpResult &self) {
1667       return mlirOpResultGetResultNumber(self.get());
1668     });
1669   }
1670 };
1671 
1672 /// Returns the list of types of the values held by container.
1673 template <typename Container>
1674 static std::vector<PyType> getValueTypes(Container &container,
1675                                          PyMlirContextRef &context) {
1676   std::vector<PyType> result;
1677   result.reserve(container.getNumElements());
1678   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1679     result.push_back(
1680         PyType(context, mlirValueGetType(container.getElement(i).get())));
1681   }
1682   return result;
1683 }
1684 
1685 /// A list of block arguments. Internally, these are stored as consecutive
1686 /// elements, random access is cheap. The argument list is associated with the
1687 /// operation that contains the block (detached blocks are not allowed in
1688 /// Python bindings) and extends its lifetime.
1689 class PyBlockArgumentList
1690     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1691 public:
1692   static constexpr const char *pyClassName = "BlockArgumentList";
1693 
1694   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1695                       intptr_t startIndex = 0, intptr_t length = -1,
1696                       intptr_t step = 1)
1697       : Sliceable(startIndex,
1698                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1699                   step),
1700         operation(std::move(operation)), block(block) {}
1701 
1702   /// Returns the number of arguments in the list.
1703   intptr_t getNumElements() {
1704     operation->checkValid();
1705     return mlirBlockGetNumArguments(block);
1706   }
1707 
1708   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1709   PyBlockArgument getElement(intptr_t pos) {
1710     MlirValue argument = mlirBlockGetArgument(block, pos);
1711     return PyBlockArgument(operation, argument);
1712   }
1713 
1714   /// Returns a sublist of this list.
1715   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1716                             intptr_t step) {
1717     return PyBlockArgumentList(operation, block, startIndex, length, step);
1718   }
1719 
1720   static void bindDerived(ClassTy &c) {
1721     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1722       return getValueTypes(self, self.operation->getContext());
1723     });
1724   }
1725 
1726 private:
1727   PyOperationRef operation;
1728   MlirBlock block;
1729 };
1730 
1731 /// A list of operation operands. Internally, these are stored as consecutive
1732 /// elements, random access is cheap. The result list is associated with the
1733 /// operation whose results these are, and extends the lifetime of this
1734 /// operation.
1735 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1736 public:
1737   static constexpr const char *pyClassName = "OpOperandList";
1738 
1739   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1740                   intptr_t length = -1, intptr_t step = 1)
1741       : Sliceable(startIndex,
1742                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1743                                : length,
1744                   step),
1745         operation(operation) {}
1746 
1747   intptr_t getNumElements() {
1748     operation->checkValid();
1749     return mlirOperationGetNumOperands(operation->get());
1750   }
1751 
1752   PyValue getElement(intptr_t pos) {
1753     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1754     MlirOperation owner;
1755     if (mlirValueIsAOpResult(operand))
1756       owner = mlirOpResultGetOwner(operand);
1757     else if (mlirValueIsABlockArgument(operand))
1758       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1759     else
1760       assert(false && "Value must be an block arg or op result.");
1761     PyOperationRef pyOwner =
1762         PyOperation::forOperation(operation->getContext(), owner);
1763     return PyValue(pyOwner, operand);
1764   }
1765 
1766   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1767     return PyOpOperandList(operation, startIndex, length, step);
1768   }
1769 
1770   void dunderSetItem(intptr_t index, PyValue value) {
1771     index = wrapIndex(index);
1772     mlirOperationSetOperand(operation->get(), index, value.get());
1773   }
1774 
1775   static void bindDerived(ClassTy &c) {
1776     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1777   }
1778 
1779 private:
1780   PyOperationRef operation;
1781 };
1782 
1783 /// A list of operation results. Internally, these are stored as consecutive
1784 /// elements, random access is cheap. The result list is associated with the
1785 /// operation whose results these are, and extends the lifetime of this
1786 /// operation.
1787 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1788 public:
1789   static constexpr const char *pyClassName = "OpResultList";
1790 
1791   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1792                  intptr_t length = -1, intptr_t step = 1)
1793       : Sliceable(startIndex,
1794                   length == -1 ? mlirOperationGetNumResults(operation->get())
1795                                : length,
1796                   step),
1797         operation(operation) {}
1798 
1799   intptr_t getNumElements() {
1800     operation->checkValid();
1801     return mlirOperationGetNumResults(operation->get());
1802   }
1803 
1804   PyOpResult getElement(intptr_t index) {
1805     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1806     return PyOpResult(value);
1807   }
1808 
1809   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1810     return PyOpResultList(operation, startIndex, length, step);
1811   }
1812 
1813   static void bindDerived(ClassTy &c) {
1814     c.def_property_readonly("types", [](PyOpResultList &self) {
1815       return getValueTypes(self, self.operation->getContext());
1816     });
1817   }
1818 
1819 private:
1820   PyOperationRef operation;
1821 };
1822 
1823 /// A list of operation attributes. Can be indexed by name, producing
1824 /// attributes, or by index, producing named attributes.
1825 class PyOpAttributeMap {
1826 public:
1827   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1828 
1829   PyAttribute dunderGetItemNamed(const std::string &name) {
1830     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1831                                                          toMlirStringRef(name));
1832     if (mlirAttributeIsNull(attr)) {
1833       throw SetPyError(PyExc_KeyError,
1834                        "attempt to access a non-existent attribute");
1835     }
1836     return PyAttribute(operation->getContext(), attr);
1837   }
1838 
1839   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1840     if (index < 0 || index >= dunderLen()) {
1841       throw SetPyError(PyExc_IndexError,
1842                        "attempt to access out of bounds attribute");
1843     }
1844     MlirNamedAttribute namedAttr =
1845         mlirOperationGetAttribute(operation->get(), index);
1846     return PyNamedAttribute(
1847         namedAttr.attribute,
1848         std::string(mlirIdentifierStr(namedAttr.name).data));
1849   }
1850 
1851   void dunderSetItem(const std::string &name, PyAttribute attr) {
1852     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1853                                     attr);
1854   }
1855 
1856   void dunderDelItem(const std::string &name) {
1857     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1858                                                      toMlirStringRef(name));
1859     if (!removed)
1860       throw SetPyError(PyExc_KeyError,
1861                        "attempt to delete a non-existent attribute");
1862   }
1863 
1864   intptr_t dunderLen() {
1865     return mlirOperationGetNumAttributes(operation->get());
1866   }
1867 
1868   bool dunderContains(const std::string &name) {
1869     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1870         operation->get(), toMlirStringRef(name)));
1871   }
1872 
1873   static void bind(py::module &m) {
1874     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
1875         .def("__contains__", &PyOpAttributeMap::dunderContains)
1876         .def("__len__", &PyOpAttributeMap::dunderLen)
1877         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1878         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1879         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1880         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1881   }
1882 
1883 private:
1884   PyOperationRef operation;
1885 };
1886 
1887 } // end namespace
1888 
1889 //------------------------------------------------------------------------------
1890 // Populates the core exports of the 'ir' submodule.
1891 //------------------------------------------------------------------------------
1892 
1893 void mlir::python::populateIRCore(py::module &m) {
1894   //----------------------------------------------------------------------------
1895   // Mapping of MlirContext.
1896   //----------------------------------------------------------------------------
1897   py::class_<PyMlirContext>(m, "Context", py::module_local())
1898       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1899       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1900       .def("_get_context_again",
1901            [](PyMlirContext &self) {
1902              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1903              return ref.releaseObject();
1904            })
1905       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1906       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1907       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1908                              &PyMlirContext::getCapsule)
1909       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1910       .def("__enter__", &PyMlirContext::contextEnter)
1911       .def("__exit__", &PyMlirContext::contextExit)
1912       .def_property_readonly_static(
1913           "current",
1914           [](py::object & /*class*/) {
1915             auto *context = PyThreadContextEntry::getDefaultContext();
1916             if (!context)
1917               throw SetPyError(PyExc_ValueError, "No current Context");
1918             return context;
1919           },
1920           "Gets the Context bound to the current thread or raises ValueError")
1921       .def_property_readonly(
1922           "dialects",
1923           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1924           "Gets a container for accessing dialects by name")
1925       .def_property_readonly(
1926           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1927           "Alias for 'dialect'")
1928       .def(
1929           "get_dialect_descriptor",
1930           [=](PyMlirContext &self, std::string &name) {
1931             MlirDialect dialect = mlirContextGetOrLoadDialect(
1932                 self.get(), {name.data(), name.size()});
1933             if (mlirDialectIsNull(dialect)) {
1934               throw SetPyError(PyExc_ValueError,
1935                                Twine("Dialect '") + name + "' not found");
1936             }
1937             return PyDialectDescriptor(self.getRef(), dialect);
1938           },
1939           "Gets or loads a dialect by name, returning its descriptor object")
1940       .def_property(
1941           "allow_unregistered_dialects",
1942           [](PyMlirContext &self) -> bool {
1943             return mlirContextGetAllowUnregisteredDialects(self.get());
1944           },
1945           [](PyMlirContext &self, bool value) {
1946             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1947           })
1948       .def("enable_multithreading",
1949            [](PyMlirContext &self, bool enable) {
1950              mlirContextEnableMultithreading(self.get(), enable);
1951            })
1952       .def("is_registered_operation",
1953            [](PyMlirContext &self, std::string &name) {
1954              return mlirContextIsRegisteredOperation(
1955                  self.get(), MlirStringRef{name.data(), name.size()});
1956            });
1957 
1958   //----------------------------------------------------------------------------
1959   // Mapping of PyDialectDescriptor
1960   //----------------------------------------------------------------------------
1961   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
1962       .def_property_readonly("namespace",
1963                              [](PyDialectDescriptor &self) {
1964                                MlirStringRef ns =
1965                                    mlirDialectGetNamespace(self.get());
1966                                return py::str(ns.data, ns.length);
1967                              })
1968       .def("__repr__", [](PyDialectDescriptor &self) {
1969         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1970         std::string repr("<DialectDescriptor ");
1971         repr.append(ns.data, ns.length);
1972         repr.append(">");
1973         return repr;
1974       });
1975 
1976   //----------------------------------------------------------------------------
1977   // Mapping of PyDialects
1978   //----------------------------------------------------------------------------
1979   py::class_<PyDialects>(m, "Dialects", py::module_local())
1980       .def("__getitem__",
1981            [=](PyDialects &self, std::string keyName) {
1982              MlirDialect dialect =
1983                  self.getDialectForKey(keyName, /*attrError=*/false);
1984              py::object descriptor =
1985                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1986              return createCustomDialectWrapper(keyName, std::move(descriptor));
1987            })
1988       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1989         MlirDialect dialect =
1990             self.getDialectForKey(attrName, /*attrError=*/true);
1991         py::object descriptor =
1992             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1993         return createCustomDialectWrapper(attrName, std::move(descriptor));
1994       });
1995 
1996   //----------------------------------------------------------------------------
1997   // Mapping of PyDialect
1998   //----------------------------------------------------------------------------
1999   py::class_<PyDialect>(m, "Dialect", py::module_local())
2000       .def(py::init<py::object>(), "descriptor")
2001       .def_property_readonly(
2002           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2003       .def("__repr__", [](py::object self) {
2004         auto clazz = self.attr("__class__");
2005         return py::str("<Dialect ") +
2006                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2007                clazz.attr("__module__") + py::str(".") +
2008                clazz.attr("__name__") + py::str(")>");
2009       });
2010 
2011   //----------------------------------------------------------------------------
2012   // Mapping of Location
2013   //----------------------------------------------------------------------------
2014   py::class_<PyLocation>(m, "Location", py::module_local())
2015       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2016       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2017       .def("__enter__", &PyLocation::contextEnter)
2018       .def("__exit__", &PyLocation::contextExit)
2019       .def("__eq__",
2020            [](PyLocation &self, PyLocation &other) -> bool {
2021              return mlirLocationEqual(self, other);
2022            })
2023       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2024       .def_property_readonly_static(
2025           "current",
2026           [](py::object & /*class*/) {
2027             auto *loc = PyThreadContextEntry::getDefaultLocation();
2028             if (!loc)
2029               throw SetPyError(PyExc_ValueError, "No current Location");
2030             return loc;
2031           },
2032           "Gets the Location bound to the current thread or raises ValueError")
2033       .def_static(
2034           "unknown",
2035           [](DefaultingPyMlirContext context) {
2036             return PyLocation(context->getRef(),
2037                               mlirLocationUnknownGet(context->get()));
2038           },
2039           py::arg("context") = py::none(),
2040           "Gets a Location representing an unknown location")
2041       .def_static(
2042           "callsite",
2043           [](PyLocation callee, const std::vector<PyLocation> &frames,
2044              DefaultingPyMlirContext context) {
2045             if (frames.empty())
2046               throw py::value_error("No caller frames provided");
2047             MlirLocation caller = frames.back().get();
2048             for (const PyLocation &frame :
2049                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2050               caller = mlirLocationCallSiteGet(frame.get(), caller);
2051             return PyLocation(context->getRef(),
2052                               mlirLocationCallSiteGet(callee.get(), caller));
2053           },
2054           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2055           kContextGetCallSiteLocationDocstring)
2056       .def_static(
2057           "file",
2058           [](std::string filename, int line, int col,
2059              DefaultingPyMlirContext context) {
2060             return PyLocation(
2061                 context->getRef(),
2062                 mlirLocationFileLineColGet(
2063                     context->get(), toMlirStringRef(filename), line, col));
2064           },
2065           py::arg("filename"), py::arg("line"), py::arg("col"),
2066           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2067       .def_static(
2068           "name",
2069           [](std::string name, llvm::Optional<PyLocation> childLoc,
2070              DefaultingPyMlirContext context) {
2071             return PyLocation(
2072                 context->getRef(),
2073                 mlirLocationNameGet(
2074                     context->get(), toMlirStringRef(name),
2075                     childLoc ? childLoc->get()
2076                              : mlirLocationUnknownGet(context->get())));
2077           },
2078           py::arg("name"), py::arg("childLoc") = py::none(),
2079           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2080       .def_property_readonly(
2081           "context",
2082           [](PyLocation &self) { return self.getContext().getObject(); },
2083           "Context that owns the Location")
2084       .def("__repr__", [](PyLocation &self) {
2085         PyPrintAccumulator printAccum;
2086         mlirLocationPrint(self, printAccum.getCallback(),
2087                           printAccum.getUserData());
2088         return printAccum.join();
2089       });
2090 
2091   //----------------------------------------------------------------------------
2092   // Mapping of Module
2093   //----------------------------------------------------------------------------
2094   py::class_<PyModule>(m, "Module", py::module_local())
2095       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2096       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2097       .def_static(
2098           "parse",
2099           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2100             MlirModule module = mlirModuleCreateParse(
2101                 context->get(), toMlirStringRef(moduleAsm));
2102             // TODO: Rework error reporting once diagnostic engine is exposed
2103             // in C API.
2104             if (mlirModuleIsNull(module)) {
2105               throw SetPyError(
2106                   PyExc_ValueError,
2107                   "Unable to parse module assembly (see diagnostics)");
2108             }
2109             return PyModule::forModule(module).releaseObject();
2110           },
2111           py::arg("asm"), py::arg("context") = py::none(),
2112           kModuleParseDocstring)
2113       .def_static(
2114           "create",
2115           [](DefaultingPyLocation loc) {
2116             MlirModule module = mlirModuleCreateEmpty(loc);
2117             return PyModule::forModule(module).releaseObject();
2118           },
2119           py::arg("loc") = py::none(), "Creates an empty module")
2120       .def_property_readonly(
2121           "context",
2122           [](PyModule &self) { return self.getContext().getObject(); },
2123           "Context that created the Module")
2124       .def_property_readonly(
2125           "operation",
2126           [](PyModule &self) {
2127             return PyOperation::forOperation(self.getContext(),
2128                                              mlirModuleGetOperation(self.get()),
2129                                              self.getRef().releaseObject())
2130                 .releaseObject();
2131           },
2132           "Accesses the module as an operation")
2133       .def_property_readonly(
2134           "body",
2135           [](PyModule &self) {
2136             PyOperationRef module_op = PyOperation::forOperation(
2137                 self.getContext(), mlirModuleGetOperation(self.get()),
2138                 self.getRef().releaseObject());
2139             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2140             return returnBlock;
2141           },
2142           "Return the block for this module")
2143       .def(
2144           "dump",
2145           [](PyModule &self) {
2146             mlirOperationDump(mlirModuleGetOperation(self.get()));
2147           },
2148           kDumpDocstring)
2149       .def(
2150           "__str__",
2151           [](PyModule &self) {
2152             MlirOperation operation = mlirModuleGetOperation(self.get());
2153             PyPrintAccumulator printAccum;
2154             mlirOperationPrint(operation, printAccum.getCallback(),
2155                                printAccum.getUserData());
2156             return printAccum.join();
2157           },
2158           kOperationStrDunderDocstring);
2159 
2160   //----------------------------------------------------------------------------
2161   // Mapping of Operation.
2162   //----------------------------------------------------------------------------
2163   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2164       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2165                              [](PyOperationBase &self) {
2166                                return self.getOperation().getCapsule();
2167                              })
2168       .def("__eq__",
2169            [](PyOperationBase &self, PyOperationBase &other) {
2170              return &self.getOperation() == &other.getOperation();
2171            })
2172       .def("__eq__",
2173            [](PyOperationBase &self, py::object other) { return false; })
2174       .def("__hash__",
2175            [](PyOperationBase &self) {
2176              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2177            })
2178       .def_property_readonly("attributes",
2179                              [](PyOperationBase &self) {
2180                                return PyOpAttributeMap(
2181                                    self.getOperation().getRef());
2182                              })
2183       .def_property_readonly("operands",
2184                              [](PyOperationBase &self) {
2185                                return PyOpOperandList(
2186                                    self.getOperation().getRef());
2187                              })
2188       .def_property_readonly("regions",
2189                              [](PyOperationBase &self) {
2190                                return PyRegionList(
2191                                    self.getOperation().getRef());
2192                              })
2193       .def_property_readonly(
2194           "results",
2195           [](PyOperationBase &self) {
2196             return PyOpResultList(self.getOperation().getRef());
2197           },
2198           "Returns the list of Operation results.")
2199       .def_property_readonly(
2200           "result",
2201           [](PyOperationBase &self) {
2202             auto &operation = self.getOperation();
2203             auto numResults = mlirOperationGetNumResults(operation);
2204             if (numResults != 1) {
2205               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2206               throw SetPyError(
2207                   PyExc_ValueError,
2208                   Twine("Cannot call .result on operation ") +
2209                       StringRef(name.data, name.length) + " which has " +
2210                       Twine(numResults) +
2211                       " results (it is only valid for operations with a "
2212                       "single result)");
2213             }
2214             return PyOpResult(operation.getRef(),
2215                               mlirOperationGetResult(operation, 0));
2216           },
2217           "Shortcut to get an op result if it has only one (throws an error "
2218           "otherwise).")
2219       .def_property_readonly(
2220           "location",
2221           [](PyOperationBase &self) {
2222             PyOperation &operation = self.getOperation();
2223             return PyLocation(operation.getContext(),
2224                               mlirOperationGetLocation(operation.get()));
2225           },
2226           "Returns the source location the operation was defined or derived "
2227           "from.")
2228       .def(
2229           "__str__",
2230           [](PyOperationBase &self) {
2231             return self.getAsm(/*binary=*/false,
2232                                /*largeElementsLimit=*/llvm::None,
2233                                /*enableDebugInfo=*/false,
2234                                /*prettyDebugInfo=*/false,
2235                                /*printGenericOpForm=*/false,
2236                                /*useLocalScope=*/false);
2237           },
2238           "Returns the assembly form of the operation.")
2239       .def("print", &PyOperationBase::print,
2240            // Careful: Lots of arguments must match up with print method.
2241            py::arg("file") = py::none(), py::arg("binary") = false,
2242            py::arg("large_elements_limit") = py::none(),
2243            py::arg("enable_debug_info") = false,
2244            py::arg("pretty_debug_info") = false,
2245            py::arg("print_generic_op_form") = false,
2246            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2247       .def("get_asm", &PyOperationBase::getAsm,
2248            // Careful: Lots of arguments must match up with get_asm method.
2249            py::arg("binary") = false,
2250            py::arg("large_elements_limit") = py::none(),
2251            py::arg("enable_debug_info") = false,
2252            py::arg("pretty_debug_info") = false,
2253            py::arg("print_generic_op_form") = false,
2254            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2255       .def(
2256           "verify",
2257           [](PyOperationBase &self) {
2258             return mlirOperationVerify(self.getOperation());
2259           },
2260           "Verify the operation and return true if it passes, false if it "
2261           "fails.")
2262       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2263            "Puts self immediately after the other operation in its parent "
2264            "block.")
2265       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2266            "Puts self immediately before the other operation in its parent "
2267            "block.")
2268       .def(
2269           "detach_from_parent",
2270           [](PyOperationBase &self) {
2271             PyOperation &operation = self.getOperation();
2272             operation.checkValid();
2273             if (!operation.isAttached())
2274               throw py::value_error("Detached operation has no parent.");
2275 
2276             operation.detachFromParent();
2277             return operation.createOpView();
2278           },
2279           "Detaches the operation from its parent block.");
2280 
2281   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2282       .def_static("create", &PyOperation::create, py::arg("name"),
2283                   py::arg("results") = py::none(),
2284                   py::arg("operands") = py::none(),
2285                   py::arg("attributes") = py::none(),
2286                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2287                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2288                   kOperationCreateDocstring)
2289       .def_property_readonly("parent",
2290                              [](PyOperation &self) -> py::object {
2291                                auto parent = self.getParentOperation();
2292                                if (parent)
2293                                  return parent->getObject();
2294                                return py::none();
2295                              })
2296       .def("erase", &PyOperation::erase)
2297       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2298                              &PyOperation::getCapsule)
2299       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2300       .def_property_readonly("name",
2301                              [](PyOperation &self) {
2302                                self.checkValid();
2303                                MlirOperation operation = self.get();
2304                                MlirStringRef name = mlirIdentifierStr(
2305                                    mlirOperationGetName(operation));
2306                                return py::str(name.data, name.length);
2307                              })
2308       .def_property_readonly(
2309           "context",
2310           [](PyOperation &self) {
2311             self.checkValid();
2312             return self.getContext().getObject();
2313           },
2314           "Context that owns the Operation")
2315       .def_property_readonly("opview", &PyOperation::createOpView);
2316 
2317   auto opViewClass =
2318       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2319           .def(py::init<py::object>())
2320           .def_property_readonly("operation", &PyOpView::getOperationObject)
2321           .def_property_readonly(
2322               "context",
2323               [](PyOpView &self) {
2324                 return self.getOperation().getContext().getObject();
2325               },
2326               "Context that owns the Operation")
2327           .def("__str__", [](PyOpView &self) {
2328             return py::str(self.getOperationObject());
2329           });
2330   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2331   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2332   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2333   opViewClass.attr("build_generic") = classmethod(
2334       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2335       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2336       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2337       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2338       "Builds a specific, generated OpView based on class level attributes.");
2339 
2340   //----------------------------------------------------------------------------
2341   // Mapping of PyRegion.
2342   //----------------------------------------------------------------------------
2343   py::class_<PyRegion>(m, "Region", py::module_local())
2344       .def_property_readonly(
2345           "blocks",
2346           [](PyRegion &self) {
2347             return PyBlockList(self.getParentOperation(), self.get());
2348           },
2349           "Returns a forward-optimized sequence of blocks.")
2350       .def_property_readonly(
2351           "owner",
2352           [](PyRegion &self) {
2353             return self.getParentOperation()->createOpView();
2354           },
2355           "Returns the operation owning this region.")
2356       .def(
2357           "__iter__",
2358           [](PyRegion &self) {
2359             self.checkValid();
2360             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2361             return PyBlockIterator(self.getParentOperation(), firstBlock);
2362           },
2363           "Iterates over blocks in the region.")
2364       .def("__eq__",
2365            [](PyRegion &self, PyRegion &other) {
2366              return self.get().ptr == other.get().ptr;
2367            })
2368       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2369 
2370   //----------------------------------------------------------------------------
2371   // Mapping of PyBlock.
2372   //----------------------------------------------------------------------------
2373   py::class_<PyBlock>(m, "Block", py::module_local())
2374       .def_property_readonly(
2375           "owner",
2376           [](PyBlock &self) {
2377             return self.getParentOperation()->createOpView();
2378           },
2379           "Returns the owning operation of this block.")
2380       .def_property_readonly(
2381           "region",
2382           [](PyBlock &self) {
2383             MlirRegion region = mlirBlockGetParentRegion(self.get());
2384             return PyRegion(self.getParentOperation(), region);
2385           },
2386           "Returns the owning region of this block.")
2387       .def_property_readonly(
2388           "arguments",
2389           [](PyBlock &self) {
2390             return PyBlockArgumentList(self.getParentOperation(), self.get());
2391           },
2392           "Returns a list of block arguments.")
2393       .def_property_readonly(
2394           "operations",
2395           [](PyBlock &self) {
2396             return PyOperationList(self.getParentOperation(), self.get());
2397           },
2398           "Returns a forward-optimized sequence of operations.")
2399       .def_static(
2400           "create_at_start",
2401           [](PyRegion &parent, py::list pyArgTypes) {
2402             parent.checkValid();
2403             llvm::SmallVector<MlirType, 4> argTypes;
2404             argTypes.reserve(pyArgTypes.size());
2405             for (auto &pyArg : pyArgTypes) {
2406               argTypes.push_back(pyArg.cast<PyType &>());
2407             }
2408 
2409             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2410             mlirRegionInsertOwnedBlock(parent, 0, block);
2411             return PyBlock(parent.getParentOperation(), block);
2412           },
2413           py::arg("parent"), py::arg("pyArgTypes") = py::list(),
2414           "Creates and returns a new Block at the beginning of the given "
2415           "region (with given argument types).")
2416       .def(
2417           "create_before",
2418           [](PyBlock &self, py::args pyArgTypes) {
2419             self.checkValid();
2420             llvm::SmallVector<MlirType, 4> argTypes;
2421             argTypes.reserve(pyArgTypes.size());
2422             for (auto &pyArg : pyArgTypes) {
2423               argTypes.push_back(pyArg.cast<PyType &>());
2424             }
2425 
2426             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2427             MlirRegion region = mlirBlockGetParentRegion(self.get());
2428             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2429             return PyBlock(self.getParentOperation(), block);
2430           },
2431           "Creates and returns a new Block before this block "
2432           "(with given argument types).")
2433       .def(
2434           "create_after",
2435           [](PyBlock &self, py::args pyArgTypes) {
2436             self.checkValid();
2437             llvm::SmallVector<MlirType, 4> argTypes;
2438             argTypes.reserve(pyArgTypes.size());
2439             for (auto &pyArg : pyArgTypes) {
2440               argTypes.push_back(pyArg.cast<PyType &>());
2441             }
2442 
2443             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2444             MlirRegion region = mlirBlockGetParentRegion(self.get());
2445             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2446             return PyBlock(self.getParentOperation(), block);
2447           },
2448           "Creates and returns a new Block after this block "
2449           "(with given argument types).")
2450       .def(
2451           "__iter__",
2452           [](PyBlock &self) {
2453             self.checkValid();
2454             MlirOperation firstOperation =
2455                 mlirBlockGetFirstOperation(self.get());
2456             return PyOperationIterator(self.getParentOperation(),
2457                                        firstOperation);
2458           },
2459           "Iterates over operations in the block.")
2460       .def("__eq__",
2461            [](PyBlock &self, PyBlock &other) {
2462              return self.get().ptr == other.get().ptr;
2463            })
2464       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2465       .def(
2466           "__str__",
2467           [](PyBlock &self) {
2468             self.checkValid();
2469             PyPrintAccumulator printAccum;
2470             mlirBlockPrint(self.get(), printAccum.getCallback(),
2471                            printAccum.getUserData());
2472             return printAccum.join();
2473           },
2474           "Returns the assembly form of the block.")
2475       .def(
2476           "append",
2477           [](PyBlock &self, PyOperationBase &operation) {
2478             if (operation.getOperation().isAttached())
2479               operation.getOperation().detachFromParent();
2480 
2481             MlirOperation mlirOperation = operation.getOperation().get();
2482             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2483             operation.getOperation().setAttached(
2484                 self.getParentOperation().getObject());
2485           },
2486           "Appends an operation to this block. If the operation is currently "
2487           "in another block, it will be moved.");
2488 
2489   //----------------------------------------------------------------------------
2490   // Mapping of PyInsertionPoint.
2491   //----------------------------------------------------------------------------
2492 
2493   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2494       .def(py::init<PyBlock &>(), py::arg("block"),
2495            "Inserts after the last operation but still inside the block.")
2496       .def("__enter__", &PyInsertionPoint::contextEnter)
2497       .def("__exit__", &PyInsertionPoint::contextExit)
2498       .def_property_readonly_static(
2499           "current",
2500           [](py::object & /*class*/) {
2501             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2502             if (!ip)
2503               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2504             return ip;
2505           },
2506           "Gets the InsertionPoint bound to the current thread or raises "
2507           "ValueError if none has been set")
2508       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2509            "Inserts before a referenced operation.")
2510       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2511                   py::arg("block"), "Inserts at the beginning of the block.")
2512       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2513                   py::arg("block"), "Inserts before the block terminator.")
2514       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2515            "Inserts an operation.")
2516       .def_property_readonly(
2517           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2518           "Returns the block that this InsertionPoint points to.");
2519 
2520   //----------------------------------------------------------------------------
2521   // Mapping of PyAttribute.
2522   //----------------------------------------------------------------------------
2523   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2524       // Delegate to the PyAttribute copy constructor, which will also lifetime
2525       // extend the backing context which owns the MlirAttribute.
2526       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2527            "Casts the passed attribute to the generic Attribute")
2528       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2529                              &PyAttribute::getCapsule)
2530       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2531       .def_static(
2532           "parse",
2533           [](std::string attrSpec, DefaultingPyMlirContext context) {
2534             MlirAttribute type = mlirAttributeParseGet(
2535                 context->get(), toMlirStringRef(attrSpec));
2536             // TODO: Rework error reporting once diagnostic engine is exposed
2537             // in C API.
2538             if (mlirAttributeIsNull(type)) {
2539               throw SetPyError(PyExc_ValueError,
2540                                Twine("Unable to parse attribute: '") +
2541                                    attrSpec + "'");
2542             }
2543             return PyAttribute(context->getRef(), type);
2544           },
2545           py::arg("asm"), py::arg("context") = py::none(),
2546           "Parses an attribute from an assembly form")
2547       .def_property_readonly(
2548           "context",
2549           [](PyAttribute &self) { return self.getContext().getObject(); },
2550           "Context that owns the Attribute")
2551       .def_property_readonly("type",
2552                              [](PyAttribute &self) {
2553                                return PyType(self.getContext()->getRef(),
2554                                              mlirAttributeGetType(self));
2555                              })
2556       .def(
2557           "get_named",
2558           [](PyAttribute &self, std::string name) {
2559             return PyNamedAttribute(self, std::move(name));
2560           },
2561           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2562       .def("__eq__",
2563            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2564       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2565       .def("__hash__",
2566            [](PyAttribute &self) {
2567              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2568            })
2569       .def(
2570           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2571           kDumpDocstring)
2572       .def(
2573           "__str__",
2574           [](PyAttribute &self) {
2575             PyPrintAccumulator printAccum;
2576             mlirAttributePrint(self, printAccum.getCallback(),
2577                                printAccum.getUserData());
2578             return printAccum.join();
2579           },
2580           "Returns the assembly form of the Attribute.")
2581       .def("__repr__", [](PyAttribute &self) {
2582         // Generally, assembly formats are not printed for __repr__ because
2583         // this can cause exceptionally long debug output and exceptions.
2584         // However, attribute values are generally considered useful and are
2585         // printed. This may need to be re-evaluated if debug dumps end up
2586         // being excessive.
2587         PyPrintAccumulator printAccum;
2588         printAccum.parts.append("Attribute(");
2589         mlirAttributePrint(self, printAccum.getCallback(),
2590                            printAccum.getUserData());
2591         printAccum.parts.append(")");
2592         return printAccum.join();
2593       });
2594 
2595   //----------------------------------------------------------------------------
2596   // Mapping of PyNamedAttribute
2597   //----------------------------------------------------------------------------
2598   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2599       .def("__repr__",
2600            [](PyNamedAttribute &self) {
2601              PyPrintAccumulator printAccum;
2602              printAccum.parts.append("NamedAttribute(");
2603              printAccum.parts.append(
2604                  mlirIdentifierStr(self.namedAttr.name).data);
2605              printAccum.parts.append("=");
2606              mlirAttributePrint(self.namedAttr.attribute,
2607                                 printAccum.getCallback(),
2608                                 printAccum.getUserData());
2609              printAccum.parts.append(")");
2610              return printAccum.join();
2611            })
2612       .def_property_readonly(
2613           "name",
2614           [](PyNamedAttribute &self) {
2615             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2616                            mlirIdentifierStr(self.namedAttr.name).length);
2617           },
2618           "The name of the NamedAttribute binding")
2619       .def_property_readonly(
2620           "attr",
2621           [](PyNamedAttribute &self) {
2622             // TODO: When named attribute is removed/refactored, also remove
2623             // this constructor (it does an inefficient table lookup).
2624             auto contextRef = PyMlirContext::forContext(
2625                 mlirAttributeGetContext(self.namedAttr.attribute));
2626             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2627           },
2628           py::keep_alive<0, 1>(),
2629           "The underlying generic attribute of the NamedAttribute binding");
2630 
2631   //----------------------------------------------------------------------------
2632   // Mapping of PyType.
2633   //----------------------------------------------------------------------------
2634   py::class_<PyType>(m, "Type", py::module_local())
2635       // Delegate to the PyType copy constructor, which will also lifetime
2636       // extend the backing context which owns the MlirType.
2637       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2638            "Casts the passed type to the generic Type")
2639       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2640       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2641       .def_static(
2642           "parse",
2643           [](std::string typeSpec, DefaultingPyMlirContext context) {
2644             MlirType type =
2645                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2646             // TODO: Rework error reporting once diagnostic engine is exposed
2647             // in C API.
2648             if (mlirTypeIsNull(type)) {
2649               throw SetPyError(PyExc_ValueError,
2650                                Twine("Unable to parse type: '") + typeSpec +
2651                                    "'");
2652             }
2653             return PyType(context->getRef(), type);
2654           },
2655           py::arg("asm"), py::arg("context") = py::none(),
2656           kContextParseTypeDocstring)
2657       .def_property_readonly(
2658           "context", [](PyType &self) { return self.getContext().getObject(); },
2659           "Context that owns the Type")
2660       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2661       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2662       .def("__hash__",
2663            [](PyType &self) {
2664              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2665            })
2666       .def(
2667           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2668       .def(
2669           "__str__",
2670           [](PyType &self) {
2671             PyPrintAccumulator printAccum;
2672             mlirTypePrint(self, printAccum.getCallback(),
2673                           printAccum.getUserData());
2674             return printAccum.join();
2675           },
2676           "Returns the assembly form of the type.")
2677       .def("__repr__", [](PyType &self) {
2678         // Generally, assembly formats are not printed for __repr__ because
2679         // this can cause exceptionally long debug output and exceptions.
2680         // However, types are an exception as they typically have compact
2681         // assembly forms and printing them is useful.
2682         PyPrintAccumulator printAccum;
2683         printAccum.parts.append("Type(");
2684         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2685         printAccum.parts.append(")");
2686         return printAccum.join();
2687       });
2688 
2689   //----------------------------------------------------------------------------
2690   // Mapping of Value.
2691   //----------------------------------------------------------------------------
2692   py::class_<PyValue>(m, "Value", py::module_local())
2693       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2694       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2695       .def_property_readonly(
2696           "context",
2697           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2698           "Context in which the value lives.")
2699       .def(
2700           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2701           kDumpDocstring)
2702       .def_property_readonly(
2703           "owner",
2704           [](PyValue &self) {
2705             assert(mlirOperationEqual(self.getParentOperation()->get(),
2706                                       mlirOpResultGetOwner(self.get())) &&
2707                    "expected the owner of the value in Python to match that in "
2708                    "the IR");
2709             return self.getParentOperation().getObject();
2710           })
2711       .def("__eq__",
2712            [](PyValue &self, PyValue &other) {
2713              return self.get().ptr == other.get().ptr;
2714            })
2715       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2716       .def("__hash__",
2717            [](PyValue &self) {
2718              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2719            })
2720       .def(
2721           "__str__",
2722           [](PyValue &self) {
2723             PyPrintAccumulator printAccum;
2724             printAccum.parts.append("Value(");
2725             mlirValuePrint(self.get(), printAccum.getCallback(),
2726                            printAccum.getUserData());
2727             printAccum.parts.append(")");
2728             return printAccum.join();
2729           },
2730           kValueDunderStrDocstring)
2731       .def_property_readonly("type", [](PyValue &self) {
2732         return PyType(self.getParentOperation()->getContext(),
2733                       mlirValueGetType(self.get()));
2734       });
2735   PyBlockArgument::bind(m);
2736   PyOpResult::bind(m);
2737 
2738   //----------------------------------------------------------------------------
2739   // Mapping of SymbolTable.
2740   //----------------------------------------------------------------------------
2741   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
2742       .def(py::init<PyOperationBase &>())
2743       .def("__getitem__", &PySymbolTable::dunderGetItem)
2744       .def("insert", &PySymbolTable::insert)
2745       .def("erase", &PySymbolTable::erase)
2746       .def("__delitem__", &PySymbolTable::dunderDel)
2747       .def("__contains__", [](PySymbolTable &table, const std::string &name) {
2748         return !mlirOperationIsNull(mlirSymbolTableLookup(
2749             table, mlirStringRefCreate(name.data(), name.length())));
2750       });
2751 
  // Container bindings: the list, map and iterator helper types handed out by
  // the properties and `__iter__` methods bound above (e.g. `blocks`,
  // `operations`, `arguments`, `operands`, `results`, `regions`,
  // `attributes`). Each type's static bind(m) is defined elsewhere and is
  // expected to register that type on module `m`.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings (PyGlobalDebugFlag::bind is likewise defined elsewhere).
  PyGlobalDebugFlag::bind(m);
2766 }
2767