1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include <pybind11/stl.h>
23 
24 namespace py = pybind11;
25 using namespace mlir;
26 using namespace mlir::python;
27 
28 using llvm::SmallVector;
29 using llvm::StringRef;
30 using llvm::Twine;
31 
32 //------------------------------------------------------------------------------
33 // Docstrings (trivial, non-duplicated docstrings are included inline).
34 //------------------------------------------------------------------------------
35 
/// Python-facing help text for parsing a type from its assembly form.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

/// Python-facing help text for building a call-site location.
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

/// Python-facing help text for building a file/line/column location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

/// Python-facing help text for building a named location.
/// NOTE(review): this identifier is spelled "DocString" while its siblings use
/// "Docstring"; it is left as-is because users of the name are outside this
/// section.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

/// Python-facing help text for parsing a module from a string.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

/// Python-facing help text for creating an operation; lists the keyword
/// arguments and describes the "detached" state of the returned object.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
78 
/// Python-facing help text for printing an operation to a file-like object.
/// Each entry under "Args" names one keyword argument; the names follow the
/// snake_case convention used by every other argument in this list.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";
104 
/// Python-facing help text for obtaining an operation's assembly as a str or
/// bytes object; defers to the print() help text for the shared arguments.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

/// Python-facing help text for the default string conversion of an operation.
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

/// Shared help text for "dump" style debugging helpers.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

/// Help text attached to the block list's "append" method (see the BlockList
/// binding in this file).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

/// Python-facing help text for the string conversion of a value.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
143 
144 //------------------------------------------------------------------------------
145 // Utilities.
146 //------------------------------------------------------------------------------
147 
148 /// Helper for creating an @classmethod.
149 template <class Func, typename... Args>
150 py::object classmethod(Func f, Args... args) {
151   py::object cf = py::cpp_function(f, args...);
152   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
153 }
154 
155 static py::object
156 createCustomDialectWrapper(const std::string &dialectNamespace,
157                            py::object dialectDescriptor) {
158   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
159   if (!dialectClass) {
160     // Use the base class.
161     return py::cast(PyDialect(std::move(dialectDescriptor)));
162   }
163 
164   // Create the custom implementation.
165   return (*dialectClass)(std::move(dialectDescriptor));
166 }
167 
168 static MlirStringRef toMlirStringRef(const std::string &s) {
169   return mlirStringRefCreate(s.data(), s.size());
170 }
171 
172 /// Wrapper for the global LLVM debugging flag.
173 struct PyGlobalDebugFlag {
174   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
175 
176   static bool get(py::object) { return mlirIsGlobalDebugEnabled(); }
177 
178   static void bind(py::module &m) {
179     // Debug flags.
180     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
181         .def_property_static("flag", &PyGlobalDebugFlag::get,
182                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
183   }
184 };
185 
186 //------------------------------------------------------------------------------
187 // Collections.
188 //------------------------------------------------------------------------------
189 
190 namespace {
191 
192 class PyRegionIterator {
193 public:
194   PyRegionIterator(PyOperationRef operation)
195       : operation(std::move(operation)) {}
196 
197   PyRegionIterator &dunderIter() { return *this; }
198 
199   PyRegion dunderNext() {
200     operation->checkValid();
201     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
202       throw py::stop_iteration();
203     }
204     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
205     return PyRegion(operation, region);
206   }
207 
208   static void bind(py::module &m) {
209     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
210         .def("__iter__", &PyRegionIterator::dunderIter)
211         .def("__next__", &PyRegionIterator::dunderNext);
212   }
213 
214 private:
215   PyOperationRef operation;
216   int nextIndex = 0;
217 };
218 
219 /// Regions of an op are fixed length and indexed numerically so are represented
220 /// with a sequence-like container.
221 class PyRegionList {
222 public:
223   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
224 
225   intptr_t dunderLen() {
226     operation->checkValid();
227     return mlirOperationGetNumRegions(operation->get());
228   }
229 
230   PyRegion dunderGetItem(intptr_t index) {
231     // dunderLen checks validity.
232     if (index < 0 || index >= dunderLen()) {
233       throw SetPyError(PyExc_IndexError,
234                        "attempt to access out of bounds region");
235     }
236     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
237     return PyRegion(operation, region);
238   }
239 
240   static void bind(py::module &m) {
241     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
242         .def("__len__", &PyRegionList::dunderLen)
243         .def("__getitem__", &PyRegionList::dunderGetItem);
244   }
245 
246 private:
247   PyOperationRef operation;
248 };
249 
250 class PyBlockIterator {
251 public:
252   PyBlockIterator(PyOperationRef operation, MlirBlock next)
253       : operation(std::move(operation)), next(next) {}
254 
255   PyBlockIterator &dunderIter() { return *this; }
256 
257   PyBlock dunderNext() {
258     operation->checkValid();
259     if (mlirBlockIsNull(next)) {
260       throw py::stop_iteration();
261     }
262 
263     PyBlock returnBlock(operation, next);
264     next = mlirBlockGetNextInRegion(next);
265     return returnBlock;
266   }
267 
268   static void bind(py::module &m) {
269     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
270         .def("__iter__", &PyBlockIterator::dunderIter)
271         .def("__next__", &PyBlockIterator::dunderNext);
272   }
273 
274 private:
275   PyOperationRef operation;
276   MlirBlock next;
277 };
278 
279 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
280 /// we present them as a more full-featured list-like container but optimize
281 /// it for forward iteration. Blocks are always owned by a region.
282 class PyBlockList {
283 public:
284   PyBlockList(PyOperationRef operation, MlirRegion region)
285       : operation(std::move(operation)), region(region) {}
286 
287   PyBlockIterator dunderIter() {
288     operation->checkValid();
289     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
290   }
291 
292   intptr_t dunderLen() {
293     operation->checkValid();
294     intptr_t count = 0;
295     MlirBlock block = mlirRegionGetFirstBlock(region);
296     while (!mlirBlockIsNull(block)) {
297       count += 1;
298       block = mlirBlockGetNextInRegion(block);
299     }
300     return count;
301   }
302 
303   PyBlock dunderGetItem(intptr_t index) {
304     operation->checkValid();
305     if (index < 0) {
306       throw SetPyError(PyExc_IndexError,
307                        "attempt to access out of bounds block");
308     }
309     MlirBlock block = mlirRegionGetFirstBlock(region);
310     while (!mlirBlockIsNull(block)) {
311       if (index == 0) {
312         return PyBlock(operation, block);
313       }
314       block = mlirBlockGetNextInRegion(block);
315       index -= 1;
316     }
317     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
318   }
319 
320   PyBlock appendBlock(py::args pyArgTypes) {
321     operation->checkValid();
322     llvm::SmallVector<MlirType, 4> argTypes;
323     argTypes.reserve(pyArgTypes.size());
324     for (auto &pyArg : pyArgTypes) {
325       argTypes.push_back(pyArg.cast<PyType &>());
326     }
327 
328     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
329     mlirRegionAppendOwnedBlock(region, block);
330     return PyBlock(operation, block);
331   }
332 
333   static void bind(py::module &m) {
334     py::class_<PyBlockList>(m, "BlockList", py::module_local())
335         .def("__getitem__", &PyBlockList::dunderGetItem)
336         .def("__iter__", &PyBlockList::dunderIter)
337         .def("__len__", &PyBlockList::dunderLen)
338         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
339   }
340 
341 private:
342   PyOperationRef operation;
343   MlirRegion region;
344 };
345 
346 class PyOperationIterator {
347 public:
348   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
349       : parentOperation(std::move(parentOperation)), next(next) {}
350 
351   PyOperationIterator &dunderIter() { return *this; }
352 
353   py::object dunderNext() {
354     parentOperation->checkValid();
355     if (mlirOperationIsNull(next)) {
356       throw py::stop_iteration();
357     }
358 
359     PyOperationRef returnOperation =
360         PyOperation::forOperation(parentOperation->getContext(), next);
361     next = mlirOperationGetNextInBlock(next);
362     return returnOperation->createOpView();
363   }
364 
365   static void bind(py::module &m) {
366     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
367         .def("__iter__", &PyOperationIterator::dunderIter)
368         .def("__next__", &PyOperationIterator::dunderNext);
369   }
370 
371 private:
372   PyOperationRef parentOperation;
373   MlirOperation next;
374 };
375 
376 /// Operations are exposed by the C-API as a forward-only linked list. In
377 /// Python, we present them as a more full-featured list-like container but
378 /// optimize it for forward iteration. Iterable operations are always owned
379 /// by a block.
380 class PyOperationList {
381 public:
382   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
383       : parentOperation(std::move(parentOperation)), block(block) {}
384 
385   PyOperationIterator dunderIter() {
386     parentOperation->checkValid();
387     return PyOperationIterator(parentOperation,
388                                mlirBlockGetFirstOperation(block));
389   }
390 
391   intptr_t dunderLen() {
392     parentOperation->checkValid();
393     intptr_t count = 0;
394     MlirOperation childOp = mlirBlockGetFirstOperation(block);
395     while (!mlirOperationIsNull(childOp)) {
396       count += 1;
397       childOp = mlirOperationGetNextInBlock(childOp);
398     }
399     return count;
400   }
401 
402   py::object dunderGetItem(intptr_t index) {
403     parentOperation->checkValid();
404     if (index < 0) {
405       throw SetPyError(PyExc_IndexError,
406                        "attempt to access out of bounds operation");
407     }
408     MlirOperation childOp = mlirBlockGetFirstOperation(block);
409     while (!mlirOperationIsNull(childOp)) {
410       if (index == 0) {
411         return PyOperation::forOperation(parentOperation->getContext(), childOp)
412             ->createOpView();
413       }
414       childOp = mlirOperationGetNextInBlock(childOp);
415       index -= 1;
416     }
417     throw SetPyError(PyExc_IndexError,
418                      "attempt to access out of bounds operation");
419   }
420 
421   static void bind(py::module &m) {
422     py::class_<PyOperationList>(m, "OperationList", py::module_local())
423         .def("__getitem__", &PyOperationList::dunderGetItem)
424         .def("__iter__", &PyOperationList::dunderIter)
425         .def("__len__", &PyOperationList::dunderLen);
426   }
427 
428 private:
429   PyOperationRef parentOperation;
430   MlirBlock block;
431 };
432 
433 } // namespace
434 
435 //------------------------------------------------------------------------------
436 // PyMlirContext
437 //------------------------------------------------------------------------------
438 
439 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
440   py::gil_scoped_acquire acquire;
441   auto &liveContexts = getLiveContexts();
442   liveContexts[context.ptr] = this;
443 }
444 
445 PyMlirContext::~PyMlirContext() {
446   // Note that the only public way to construct an instance is via the
447   // forContext method, which always puts the associated handle into
448   // liveContexts.
449   py::gil_scoped_acquire acquire;
450   getLiveContexts().erase(context.ptr);
451   mlirContextDestroy(context);
452 }
453 
454 py::object PyMlirContext::getCapsule() {
455   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
456 }
457 
458 py::object PyMlirContext::createFromCapsule(py::object capsule) {
459   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
460   if (mlirContextIsNull(rawContext))
461     throw py::error_already_set();
462   return forContext(rawContext).releaseObject();
463 }
464 
465 PyMlirContext *PyMlirContext::createNewContextForInit() {
466   MlirContext context = mlirContextCreate();
467   mlirRegisterAllDialects(context);
468   return new PyMlirContext(context);
469 }
470 
471 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
472   py::gil_scoped_acquire acquire;
473   auto &liveContexts = getLiveContexts();
474   auto it = liveContexts.find(context.ptr);
475   if (it == liveContexts.end()) {
476     // Create.
477     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
478     py::object pyRef = py::cast(unownedContextWrapper);
479     assert(pyRef && "cast to py::object failed");
480     liveContexts[context.ptr] = unownedContextWrapper;
481     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
482   }
483   // Use existing.
484   py::object pyRef = py::cast(it->second);
485   return PyMlirContextRef(it->second, std::move(pyRef));
486 }
487 
488 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
489   static LiveContextMap liveContexts;
490   return liveContexts;
491 }
492 
493 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
494 
495 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
496 
497 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
498 
499 pybind11::object PyMlirContext::contextEnter() {
500   return PyThreadContextEntry::pushContext(*this);
501 }
502 
503 void PyMlirContext::contextExit(pybind11::object excType,
504                                 pybind11::object excVal,
505                                 pybind11::object excTb) {
506   PyThreadContextEntry::popContext(*this);
507 }
508 
509 PyMlirContext &DefaultingPyMlirContext::resolve() {
510   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
511   if (!context) {
512     throw SetPyError(
513         PyExc_RuntimeError,
514         "An MLIR function requires a Context but none was provided in the call "
515         "or from the surrounding environment. Either pass to the function with "
516         "a 'context=' argument or establish a default using 'with Context():'");
517   }
518   return *context;
519 }
520 
521 //------------------------------------------------------------------------------
522 // PyThreadContextEntry management
523 //------------------------------------------------------------------------------
524 
525 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
526   static thread_local std::vector<PyThreadContextEntry> stack;
527   return stack;
528 }
529 
530 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
531   auto &stack = getStack();
532   if (stack.empty())
533     return nullptr;
534   return &stack.back();
535 }
536 
537 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
538                                 py::object insertionPoint,
539                                 py::object location) {
540   auto &stack = getStack();
541   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
542                      std::move(location));
543   // If the new stack has more than one entry and the context of the new top
544   // entry matches the previous, copy the insertionPoint and location from the
545   // previous entry if missing from the new top entry.
546   if (stack.size() > 1) {
547     auto &prev = *(stack.rbegin() + 1);
548     auto &current = stack.back();
549     if (current.context.is(prev.context)) {
550       // Default non-context objects from the previous entry.
551       if (!current.insertionPoint)
552         current.insertionPoint = prev.insertionPoint;
553       if (!current.location)
554         current.location = prev.location;
555     }
556   }
557 }
558 
559 PyMlirContext *PyThreadContextEntry::getContext() {
560   if (!context)
561     return nullptr;
562   return py::cast<PyMlirContext *>(context);
563 }
564 
565 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
566   if (!insertionPoint)
567     return nullptr;
568   return py::cast<PyInsertionPoint *>(insertionPoint);
569 }
570 
571 PyLocation *PyThreadContextEntry::getLocation() {
572   if (!location)
573     return nullptr;
574   return py::cast<PyLocation *>(location);
575 }
576 
577 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
578   auto *tos = getTopOfStack();
579   return tos ? tos->getContext() : nullptr;
580 }
581 
582 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
583   auto *tos = getTopOfStack();
584   return tos ? tos->getInsertionPoint() : nullptr;
585 }
586 
587 PyLocation *PyThreadContextEntry::getDefaultLocation() {
588   auto *tos = getTopOfStack();
589   return tos ? tos->getLocation() : nullptr;
590 }
591 
592 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
593   py::object contextObj = py::cast(context);
594   push(FrameKind::Context, /*context=*/contextObj,
595        /*insertionPoint=*/py::object(),
596        /*location=*/py::object());
597   return contextObj;
598 }
599 
600 void PyThreadContextEntry::popContext(PyMlirContext &context) {
601   auto &stack = getStack();
602   if (stack.empty())
603     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
604   auto &tos = stack.back();
605   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
606     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
607   stack.pop_back();
608 }
609 
610 py::object
611 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
612   py::object contextObj =
613       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
614   py::object insertionPointObj = py::cast(insertionPoint);
615   push(FrameKind::InsertionPoint,
616        /*context=*/contextObj,
617        /*insertionPoint=*/insertionPointObj,
618        /*location=*/py::object());
619   return insertionPointObj;
620 }
621 
622 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
623   auto &stack = getStack();
624   if (stack.empty())
625     throw SetPyError(PyExc_RuntimeError,
626                      "Unbalanced InsertionPoint enter/exit");
627   auto &tos = stack.back();
628   if (tos.frameKind != FrameKind::InsertionPoint &&
629       tos.getInsertionPoint() != &insertionPoint)
630     throw SetPyError(PyExc_RuntimeError,
631                      "Unbalanced InsertionPoint enter/exit");
632   stack.pop_back();
633 }
634 
635 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
636   py::object contextObj = location.getContext().getObject();
637   py::object locationObj = py::cast(location);
638   push(FrameKind::Location, /*context=*/contextObj,
639        /*insertionPoint=*/py::object(),
640        /*location=*/locationObj);
641   return locationObj;
642 }
643 
644 void PyThreadContextEntry::popLocation(PyLocation &location) {
645   auto &stack = getStack();
646   if (stack.empty())
647     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
648   auto &tos = stack.back();
649   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
650     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
651   stack.pop_back();
652 }
653 
654 //------------------------------------------------------------------------------
655 // PyDialect, PyDialectDescriptor, PyDialects
656 //------------------------------------------------------------------------------
657 
658 MlirDialect PyDialects::getDialectForKey(const std::string &key,
659                                          bool attrError) {
660   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
661                                                     {key.data(), key.size()});
662   if (mlirDialectIsNull(dialect)) {
663     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
664                      Twine("Dialect '") + key + "' not found");
665   }
666   return dialect;
667 }
668 
669 //------------------------------------------------------------------------------
670 // PyLocation
671 //------------------------------------------------------------------------------
672 
673 py::object PyLocation::getCapsule() {
674   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
675 }
676 
677 PyLocation PyLocation::createFromCapsule(py::object capsule) {
678   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
679   if (mlirLocationIsNull(rawLoc))
680     throw py::error_already_set();
681   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
682                     rawLoc);
683 }
684 
685 py::object PyLocation::contextEnter() {
686   return PyThreadContextEntry::pushLocation(*this);
687 }
688 
689 void PyLocation::contextExit(py::object excType, py::object excVal,
690                              py::object excTb) {
691   PyThreadContextEntry::popLocation(*this);
692 }
693 
694 PyLocation &DefaultingPyLocation::resolve() {
695   auto *location = PyThreadContextEntry::getDefaultLocation();
696   if (!location) {
697     throw SetPyError(
698         PyExc_RuntimeError,
699         "An MLIR function requires a Location but none was provided in the "
700         "call or from the surrounding environment. Either pass to the function "
701         "with a 'loc=' argument or establish a default using 'with loc:'");
702   }
703   return *location;
704 }
705 
706 //------------------------------------------------------------------------------
707 // PyModule
708 //------------------------------------------------------------------------------
709 
710 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
711     : BaseContextObject(std::move(contextRef)), module(module) {}
712 
713 PyModule::~PyModule() {
714   py::gil_scoped_acquire acquire;
715   auto &liveModules = getContext()->liveModules;
716   assert(liveModules.count(module.ptr) == 1 &&
717          "destroying module not in live map");
718   liveModules.erase(module.ptr);
719   mlirModuleDestroy(module);
720 }
721 
722 PyModuleRef PyModule::forModule(MlirModule module) {
723   MlirContext context = mlirModuleGetContext(module);
724   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
725 
726   py::gil_scoped_acquire acquire;
727   auto &liveModules = contextRef->liveModules;
728   auto it = liveModules.find(module.ptr);
729   if (it == liveModules.end()) {
730     // Create.
731     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
732     // Note that the default return value policy on cast is automatic_reference,
733     // which does not take ownership (delete will not be called).
734     // Just be explicit.
735     py::object pyRef =
736         py::cast(unownedModule, py::return_value_policy::take_ownership);
737     unownedModule->handle = pyRef;
738     liveModules[module.ptr] =
739         std::make_pair(unownedModule->handle, unownedModule);
740     return PyModuleRef(unownedModule, std::move(pyRef));
741   }
742   // Use existing.
743   PyModule *existing = it->second.second;
744   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
745   return PyModuleRef(existing, std::move(pyRef));
746 }
747 
748 py::object PyModule::createFromCapsule(py::object capsule) {
749   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
750   if (mlirModuleIsNull(rawModule))
751     throw py::error_already_set();
752   return forModule(rawModule).releaseObject();
753 }
754 
755 py::object PyModule::getCapsule() {
756   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
757 }
758 
759 //------------------------------------------------------------------------------
760 // PyOperation
761 //------------------------------------------------------------------------------
762 
763 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
764     : BaseContextObject(std::move(contextRef)), operation(operation) {}
765 
766 PyOperation::~PyOperation() {
767   // If the operation has already been invalidated there is nothing to do.
768   if (!valid)
769     return;
770   auto &liveOperations = getContext()->liveOperations;
771   assert(liveOperations.count(operation.ptr) == 1 &&
772          "destroying operation not in live map");
773   liveOperations.erase(operation.ptr);
774   if (!isAttached()) {
775     mlirOperationDestroy(operation);
776   }
777 }
778 
779 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
780                                            MlirOperation operation,
781                                            py::object parentKeepAlive) {
782   auto &liveOperations = contextRef->liveOperations;
783   // Create.
784   PyOperation *unownedOperation =
785       new PyOperation(std::move(contextRef), operation);
786   // Note that the default return value policy on cast is automatic_reference,
787   // which does not take ownership (delete will not be called).
788   // Just be explicit.
789   py::object pyRef =
790       py::cast(unownedOperation, py::return_value_policy::take_ownership);
791   unownedOperation->handle = pyRef;
792   if (parentKeepAlive) {
793     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
794   }
795   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
796   return PyOperationRef(unownedOperation, std::move(pyRef));
797 }
798 
799 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
800                                          MlirOperation operation,
801                                          py::object parentKeepAlive) {
802   auto &liveOperations = contextRef->liveOperations;
803   auto it = liveOperations.find(operation.ptr);
804   if (it == liveOperations.end()) {
805     // Create.
806     return createInstance(std::move(contextRef), operation,
807                           std::move(parentKeepAlive));
808   }
809   // Use existing.
810   PyOperation *existing = it->second.second;
811   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
812   return PyOperationRef(existing, std::move(pyRef));
813 }
814 
815 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
816                                            MlirOperation operation,
817                                            py::object parentKeepAlive) {
818   auto &liveOperations = contextRef->liveOperations;
819   assert(liveOperations.count(operation.ptr) == 0 &&
820          "cannot create detached operation that already exists");
821   (void)liveOperations;
822 
823   PyOperationRef created = createInstance(std::move(contextRef), operation,
824                                           std::move(parentKeepAlive));
825   created->attached = false;
826   return created;
827 }
828 
829 void PyOperation::checkValid() const {
830   if (!valid) {
831     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
832   }
833 }
834 
835 void PyOperationBase::print(py::object fileObject, bool binary,
836                             llvm::Optional<int64_t> largeElementsLimit,
837                             bool enableDebugInfo, bool prettyDebugInfo,
838                             bool printGenericOpForm, bool useLocalScope,
839                             bool assumeVerified) {
840   PyOperation &operation = getOperation();
841   operation.checkValid();
842   if (fileObject.is_none())
843     fileObject = py::module::import("sys").attr("stdout");
844 
845   if (!assumeVerified && !printGenericOpForm &&
846       !mlirOperationVerify(operation)) {
847     std::string message("// Verification failed, printing generic form\n");
848     if (binary) {
849       fileObject.attr("write")(py::bytes(message));
850     } else {
851       fileObject.attr("write")(py::str(message));
852     }
853     printGenericOpForm = true;
854   }
855 
856   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
857   if (largeElementsLimit)
858     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
859   if (enableDebugInfo)
860     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
861   if (printGenericOpForm)
862     mlirOpPrintingFlagsPrintGenericOpForm(flags);
863 
864   PyFileAccumulator accum(fileObject, binary);
865   py::gil_scoped_release();
866   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
867                               accum.getUserData());
868   mlirOpPrintingFlagsDestroy(flags);
869 }
870 
871 py::object PyOperationBase::getAsm(bool binary,
872                                    llvm::Optional<int64_t> largeElementsLimit,
873                                    bool enableDebugInfo, bool prettyDebugInfo,
874                                    bool printGenericOpForm, bool useLocalScope,
875                                    bool assumeVerified) {
876   py::object fileObject;
877   if (binary) {
878     fileObject = py::module::import("io").attr("BytesIO")();
879   } else {
880     fileObject = py::module::import("io").attr("StringIO")();
881   }
882   print(fileObject, /*binary=*/binary,
883         /*largeElementsLimit=*/largeElementsLimit,
884         /*enableDebugInfo=*/enableDebugInfo,
885         /*prettyDebugInfo=*/prettyDebugInfo,
886         /*printGenericOpForm=*/printGenericOpForm,
887         /*useLocalScope=*/useLocalScope,
888         /*assumeVerified=*/assumeVerified);
889 
890   return fileObject.attr("getvalue")();
891 }
892 
893 void PyOperationBase::moveAfter(PyOperationBase &other) {
894   PyOperation &operation = getOperation();
895   PyOperation &otherOp = other.getOperation();
896   operation.checkValid();
897   otherOp.checkValid();
898   mlirOperationMoveAfter(operation, otherOp);
899   operation.parentKeepAlive = otherOp.parentKeepAlive;
900 }
901 
902 void PyOperationBase::moveBefore(PyOperationBase &other) {
903   PyOperation &operation = getOperation();
904   PyOperation &otherOp = other.getOperation();
905   operation.checkValid();
906   otherOp.checkValid();
907   mlirOperationMoveBefore(operation, otherOp);
908   operation.parentKeepAlive = otherOp.parentKeepAlive;
909 }
910 
911 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
912   checkValid();
913   if (!isAttached())
914     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
915   MlirOperation operation = mlirOperationGetParentOperation(get());
916   if (mlirOperationIsNull(operation))
917     return {};
918   return PyOperation::forOperation(getContext(), operation);
919 }
920 
921 PyBlock PyOperation::getBlock() {
922   checkValid();
923   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
924   MlirBlock block = mlirOperationGetBlock(get());
925   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
926   assert(parentOperation && "Operation has no parent");
927   return PyBlock{std::move(*parentOperation), block};
928 }
929 
930 py::object PyOperation::getCapsule() {
931   checkValid();
932   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
933 }
934 
935 py::object PyOperation::createFromCapsule(py::object capsule) {
936   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
937   if (mlirOperationIsNull(rawOperation))
938     throw py::error_already_set();
939   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
940   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
941       .releaseObject();
942 }
943 
/// Creates a new operation named `name` and returns its OpView object.
///
/// The optional `results`, `operands`, `attributes` and `successors` are
/// validated and unpacked first; any failure there is raised before the
/// operation state is built. `regions` is the number of empty regions to
/// attach and must be >= 0. `location` supplies both the source location and
/// the context the operation is created in.
///
/// `maybeIp` controls insertion: the Python value `False` disables insertion,
/// `None` selects the thread's default insertion point (if any), and any other
/// object is cast to a PyInsertionPoint. Without insertion the operation is
/// left detached.
py::object PyOperation::create(
    std::string name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, py::object maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      // Keys must be convertible to std::string; re-throw with a message that
      // names the operation being created.
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        // `key` is only moved from after the cast above has succeeded, so the
        // handlers below still see its original contents.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Attach `regions` freshly created regions to the state.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);

  // InsertPoint active?
  // `False` (by identity) means "do not insert"; `None` means "use the
  // thread's default insertion point, if one exists".
  if (!maybeIp.is(py::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
      ip = PyThreadContextEntry::getDefaultInsertionPoint();
    } else {
      ip = py::cast<PyInsertionPoint *>(maybeIp);
    }
    if (ip)
      ip->insert(*created.get());
  }

  return created->createOpView();
}
1076 
1077 py::object PyOperation::createOpView() {
1078   checkValid();
1079   MlirIdentifier ident = mlirOperationGetName(get());
1080   MlirStringRef identStr = mlirIdentifierStr(ident);
1081   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1082       StringRef(identStr.data, identStr.length));
1083   if (opViewClass)
1084     return (*opViewClass)(getRef().getObject());
1085   return py::cast(PyOpView(getRef().getObject()));
1086 }
1087 
1088 void PyOperation::erase() {
1089   checkValid();
1090   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1091   // Python reference to a child operation is live. All children should also
1092   // have their `valid` bit set to false.
1093   auto &liveOperations = getContext()->liveOperations;
1094   if (liveOperations.count(operation.ptr))
1095     liveOperations.erase(operation.ptr);
1096   mlirOperationDestroy(operation);
1097   valid = false;
1098 }
1099 
1100 //------------------------------------------------------------------------------
1101 // PyOpView
1102 //------------------------------------------------------------------------------
1103 
1104 py::object
1105 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1106                        py::list operandList,
1107                        llvm::Optional<py::dict> attributes,
1108                        llvm::Optional<std::vector<PyBlock *>> successors,
1109                        llvm::Optional<int> regions,
1110                        DefaultingPyLocation location, py::object maybeIp) {
1111   PyMlirContextRef context = location->getContext();
1112   // Class level operation construction metadata.
1113   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1114   // Operand and result segment specs are either none, which does no
1115   // variadic unpacking, or a list of ints with segment sizes, where each
1116   // element is either a positive number (typically 1 for a scalar) or -1 to
1117   // indicate that it is derived from the length of the same-indexed operand
1118   // or result (implying that it is a list at that position).
1119   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1120   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1121 
1122   std::vector<uint32_t> operandSegmentLengths;
1123   std::vector<uint32_t> resultSegmentLengths;
1124 
1125   // Validate/determine region count.
1126   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1127   int opMinRegionCount = std::get<0>(opRegionSpec);
1128   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1129   if (!regions) {
1130     regions = opMinRegionCount;
1131   }
1132   if (*regions < opMinRegionCount) {
1133     throw py::value_error(
1134         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1135          llvm::Twine(opMinRegionCount) +
1136          " regions but was built with regions=" + llvm::Twine(*regions))
1137             .str());
1138   }
1139   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1140     throw py::value_error(
1141         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1142          llvm::Twine(opMinRegionCount) +
1143          " regions but was built with regions=" + llvm::Twine(*regions))
1144             .str());
1145   }
1146 
1147   // Unpack results.
1148   std::vector<PyType *> resultTypes;
1149   resultTypes.reserve(resultTypeList.size());
1150   if (resultSegmentSpecObj.is_none()) {
1151     // Non-variadic result unpacking.
1152     for (auto it : llvm::enumerate(resultTypeList)) {
1153       try {
1154         resultTypes.push_back(py::cast<PyType *>(it.value()));
1155         if (!resultTypes.back())
1156           throw py::cast_error();
1157       } catch (py::cast_error &err) {
1158         throw py::value_error((llvm::Twine("Result ") +
1159                                llvm::Twine(it.index()) + " of operation \"" +
1160                                name + "\" must be a Type (" + err.what() + ")")
1161                                   .str());
1162       }
1163     }
1164   } else {
1165     // Sized result unpacking.
1166     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1167     if (resultSegmentSpec.size() != resultTypeList.size()) {
1168       throw py::value_error((llvm::Twine("Operation \"") + name +
1169                              "\" requires " +
1170                              llvm::Twine(resultSegmentSpec.size()) +
1171                              " result segments but was provided " +
1172                              llvm::Twine(resultTypeList.size()))
1173                                 .str());
1174     }
1175     resultSegmentLengths.reserve(resultTypeList.size());
1176     for (auto it :
1177          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1178       int segmentSpec = std::get<1>(it.value());
1179       if (segmentSpec == 1 || segmentSpec == 0) {
1180         // Unpack unary element.
1181         try {
1182           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1183           if (resultType) {
1184             resultTypes.push_back(resultType);
1185             resultSegmentLengths.push_back(1);
1186           } else if (segmentSpec == 0) {
1187             // Allowed to be optional.
1188             resultSegmentLengths.push_back(0);
1189           } else {
1190             throw py::cast_error("was None and result is not optional");
1191           }
1192         } catch (py::cast_error &err) {
1193           throw py::value_error((llvm::Twine("Result ") +
1194                                  llvm::Twine(it.index()) + " of operation \"" +
1195                                  name + "\" must be a Type (" + err.what() +
1196                                  ")")
1197                                     .str());
1198         }
1199       } else if (segmentSpec == -1) {
1200         // Unpack sequence by appending.
1201         try {
1202           if (std::get<0>(it.value()).is_none()) {
1203             // Treat it as an empty list.
1204             resultSegmentLengths.push_back(0);
1205           } else {
1206             // Unpack the list.
1207             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1208             for (py::object segmentItem : segment) {
1209               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1210               if (!resultTypes.back()) {
1211                 throw py::cast_error("contained a None item");
1212               }
1213             }
1214             resultSegmentLengths.push_back(segment.size());
1215           }
1216         } catch (std::exception &err) {
1217           // NOTE: Sloppy to be using a catch-all here, but there are at least
1218           // three different unrelated exceptions that can be thrown in the
1219           // above "casts". Just keep the scope above small and catch them all.
1220           throw py::value_error((llvm::Twine("Result ") +
1221                                  llvm::Twine(it.index()) + " of operation \"" +
1222                                  name + "\" must be a Sequence of Types (" +
1223                                  err.what() + ")")
1224                                     .str());
1225         }
1226       } else {
1227         throw py::value_error("Unexpected segment spec");
1228       }
1229     }
1230   }
1231 
1232   // Unpack operands.
1233   std::vector<PyValue *> operands;
1234   operands.reserve(operands.size());
1235   if (operandSegmentSpecObj.is_none()) {
1236     // Non-sized operand unpacking.
1237     for (auto it : llvm::enumerate(operandList)) {
1238       try {
1239         operands.push_back(py::cast<PyValue *>(it.value()));
1240         if (!operands.back())
1241           throw py::cast_error();
1242       } catch (py::cast_error &err) {
1243         throw py::value_error((llvm::Twine("Operand ") +
1244                                llvm::Twine(it.index()) + " of operation \"" +
1245                                name + "\" must be a Value (" + err.what() + ")")
1246                                   .str());
1247       }
1248     }
1249   } else {
1250     // Sized operand unpacking.
1251     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1252     if (operandSegmentSpec.size() != operandList.size()) {
1253       throw py::value_error((llvm::Twine("Operation \"") + name +
1254                              "\" requires " +
1255                              llvm::Twine(operandSegmentSpec.size()) +
1256                              "operand segments but was provided " +
1257                              llvm::Twine(operandList.size()))
1258                                 .str());
1259     }
1260     operandSegmentLengths.reserve(operandList.size());
1261     for (auto it :
1262          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1263       int segmentSpec = std::get<1>(it.value());
1264       if (segmentSpec == 1 || segmentSpec == 0) {
1265         // Unpack unary element.
1266         try {
1267           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1268           if (operandValue) {
1269             operands.push_back(operandValue);
1270             operandSegmentLengths.push_back(1);
1271           } else if (segmentSpec == 0) {
1272             // Allowed to be optional.
1273             operandSegmentLengths.push_back(0);
1274           } else {
1275             throw py::cast_error("was None and operand is not optional");
1276           }
1277         } catch (py::cast_error &err) {
1278           throw py::value_error((llvm::Twine("Operand ") +
1279                                  llvm::Twine(it.index()) + " of operation \"" +
1280                                  name + "\" must be a Value (" + err.what() +
1281                                  ")")
1282                                     .str());
1283         }
1284       } else if (segmentSpec == -1) {
1285         // Unpack sequence by appending.
1286         try {
1287           if (std::get<0>(it.value()).is_none()) {
1288             // Treat it as an empty list.
1289             operandSegmentLengths.push_back(0);
1290           } else {
1291             // Unpack the list.
1292             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1293             for (py::object segmentItem : segment) {
1294               operands.push_back(py::cast<PyValue *>(segmentItem));
1295               if (!operands.back()) {
1296                 throw py::cast_error("contained a None item");
1297               }
1298             }
1299             operandSegmentLengths.push_back(segment.size());
1300           }
1301         } catch (std::exception &err) {
1302           // NOTE: Sloppy to be using a catch-all here, but there are at least
1303           // three different unrelated exceptions that can be thrown in the
1304           // above "casts". Just keep the scope above small and catch them all.
1305           throw py::value_error((llvm::Twine("Operand ") +
1306                                  llvm::Twine(it.index()) + " of operation \"" +
1307                                  name + "\" must be a Sequence of Values (" +
1308                                  err.what() + ")")
1309                                     .str());
1310         }
1311       } else {
1312         throw py::value_error("Unexpected segment spec");
1313       }
1314     }
1315   }
1316 
1317   // Merge operand/result segment lengths into attributes if needed.
1318   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1319     // Dup.
1320     if (attributes) {
1321       attributes = py::dict(*attributes);
1322     } else {
1323       attributes = py::dict();
1324     }
1325     if (attributes->contains("result_segment_sizes") ||
1326         attributes->contains("operand_segment_sizes")) {
1327       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1328                             "'operand_segment_sizes' attribute is unsupported. "
1329                             "Use Operation.create for such low-level access.");
1330     }
1331 
1332     // Add result_segment_sizes attribute.
1333     if (!resultSegmentLengths.empty()) {
1334       int64_t size = resultSegmentLengths.size();
1335       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1336           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1337           resultSegmentLengths.size(), resultSegmentLengths.data());
1338       (*attributes)["result_segment_sizes"] =
1339           PyAttribute(context, segmentLengthAttr);
1340     }
1341 
1342     // Add operand_segment_sizes attribute.
1343     if (!operandSegmentLengths.empty()) {
1344       int64_t size = operandSegmentLengths.size();
1345       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1346           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1347           operandSegmentLengths.size(), operandSegmentLengths.data());
1348       (*attributes)["operand_segment_sizes"] =
1349           PyAttribute(context, segmentLengthAttr);
1350     }
1351   }
1352 
1353   // Delegate to create.
1354   return PyOperation::create(std::move(name),
1355                              /*results=*/std::move(resultTypes),
1356                              /*operands=*/std::move(operands),
1357                              /*attributes=*/std::move(attributes),
1358                              /*successors=*/std::move(successors),
1359                              /*regions=*/*regions, location, maybeIp);
1360 }
1361 
/// Constructs an OpView over the operation behind `operationObject`.
///
/// `operationObject` may be any Python object castable to PyOperationBase.
/// The stored `operationObject` member is re-derived from the resolved
/// operation's own reference rather than copied from the argument, so this
/// initializer relies on `operation` being initialized first.
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1367 
1368 py::object PyOpView::createRawSubclass(py::object userClass) {
1369   // This is... a little gross. The typical pattern is to have a pure python
1370   // class that extends OpView like:
1371   //   class AddFOp(_cext.ir.OpView):
1372   //     def __init__(self, loc, lhs, rhs):
1373   //       operation = loc.context.create_operation(
1374   //           "addf", lhs, rhs, results=[lhs.type])
1375   //       super().__init__(operation)
1376   //
1377   // I.e. The goal of the user facing type is to provide a nice constructor
1378   // that has complete freedom for the op under construction. This is at odds
1379   // with our other desire to sometimes create this object by just passing an
1380   // operation (to initialize the base class). We could do *arg and **kwargs
1381   // munging to try to make it work, but instead, we synthesize a new class
1382   // on the fly which extends this user class (AddFOp in this example) and
1383   // *give it* the base class's __init__ method, thus bypassing the
1384   // intermediate subclass's __init__ method entirely. While slightly,
1385   // underhanded, this is safe/legal because the type hierarchy has not changed
1386   // (we just added a new leaf) and we aren't mucking around with __new__.
1387   // Typically, this new class will be stored on the original as "_Raw" and will
1388   // be used for casts and other things that need a variant of the class that
1389   // is initialized purely from an operation.
1390   py::object parentMetaclass =
1391       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1392   py::dict attributes;
1393   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1394   // now.
1395   //   auto opViewType = py::type::of<PyOpView>();
1396   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1397   attributes["__init__"] = opViewType.attr("__init__");
1398   py::str origName = userClass.attr("__name__");
1399   py::str newName = py::str("_") + origName;
1400   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1401 }
1402 
1403 //------------------------------------------------------------------------------
1404 // PyInsertionPoint.
1405 //------------------------------------------------------------------------------
1406 
/// Creates an insertion point for `block` without a reference operation;
/// insert() then inserts at the end of the block.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1408 
/// Creates an insertion point that inserts before `beforeOperationBase`.
/// The block is taken from that operation's getBlock(), so the `block`
/// initializer relies on `refOperation` having been initialized first.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1412 
1413 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1414   PyOperation &operation = operationBase.getOperation();
1415   if (operation.isAttached())
1416     throw SetPyError(PyExc_ValueError,
1417                      "Attempt to insert operation that is already attached");
1418   block.getParentOperation()->checkValid();
1419   MlirOperation beforeOp = {nullptr};
1420   if (refOperation) {
1421     // Insert before operation.
1422     (*refOperation)->checkValid();
1423     beforeOp = (*refOperation)->get();
1424   } else {
1425     // Insert at end (before null) is only valid if the block does not
1426     // already end in a known terminator (violating this will cause assertion
1427     // failures later).
1428     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1429       throw py::index_error("Cannot insert operation at the end of a block "
1430                             "that already has a terminator. Did you mean to "
1431                             "use 'InsertionPoint.at_block_terminator(block)' "
1432                             "versus 'InsertionPoint(block)'?");
1433     }
1434   }
1435   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1436   operation.setAttached();
1437 }
1438 
1439 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1440   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1441   if (mlirOperationIsNull(firstOp)) {
1442     // Just insert at end.
1443     return PyInsertionPoint(block);
1444   }
1445 
1446   // Insert before first op.
1447   PyOperationRef firstOpRef = PyOperation::forOperation(
1448       block.getParentOperation()->getContext(), firstOp);
1449   return PyInsertionPoint{block, std::move(firstOpRef)};
1450 }
1451 
1452 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1453   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1454   if (mlirOperationIsNull(terminator))
1455     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1456   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1457       block.getParentOperation()->getContext(), terminator);
1458   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1459 }
1460 
1461 py::object PyInsertionPoint::contextEnter() {
1462   return PyThreadContextEntry::pushInsertionPoint(*this);
1463 }
1464 
1465 void PyInsertionPoint::contextExit(pybind11::object excType,
1466                                    pybind11::object excVal,
1467                                    pybind11::object excTb) {
1468   PyThreadContextEntry::popInsertionPoint(*this);
1469 }
1470 
1471 //------------------------------------------------------------------------------
1472 // PyAttribute.
1473 //------------------------------------------------------------------------------
1474 
1475 bool PyAttribute::operator==(const PyAttribute &other) {
1476   return mlirAttributeEqual(attr, other.attr);
1477 }
1478 
1479 py::object PyAttribute::getCapsule() {
1480   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1481 }
1482 
1483 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1484   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1485   if (mlirAttributeIsNull(rawAttr))
1486     throw py::error_already_set();
1487   return PyAttribute(
1488       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1489 }
1490 
1491 //------------------------------------------------------------------------------
1492 // PyNamedAttribute.
1493 //------------------------------------------------------------------------------
1494 
1495 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1496     : ownedName(new std::string(std::move(ownedName))) {
1497   namedAttr = mlirNamedAttributeGet(
1498       mlirIdentifierGet(mlirAttributeGetContext(attr),
1499                         toMlirStringRef(*this->ownedName)),
1500       attr);
1501 }
1502 
1503 //------------------------------------------------------------------------------
1504 // PyType.
1505 //------------------------------------------------------------------------------
1506 
1507 bool PyType::operator==(const PyType &other) {
1508   return mlirTypeEqual(type, other.type);
1509 }
1510 
1511 py::object PyType::getCapsule() {
1512   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1513 }
1514 
1515 PyType PyType::createFromCapsule(py::object capsule) {
1516   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1517   if (mlirTypeIsNull(rawType))
1518     throw py::error_already_set();
1519   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1520                 rawType);
1521 }
1522 
1523 //------------------------------------------------------------------------------
// PyValue and subclasses.
1525 //------------------------------------------------------------------------------
1526 
1527 pybind11::object PyValue::getCapsule() {
1528   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1529 }
1530 
1531 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1532   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1533   if (mlirValueIsNull(value))
1534     throw py::error_already_set();
1535   MlirOperation owner;
1536   if (mlirValueIsAOpResult(value))
1537     owner = mlirOpResultGetOwner(value);
1538   if (mlirValueIsABlockArgument(value))
1539     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1540   if (mlirOperationIsNull(owner))
1541     throw py::error_already_set();
1542   MlirContext ctx = mlirOperationGetContext(owner);
1543   PyOperationRef ownerRef =
1544       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1545   return PyValue(ownerRef, value);
1546 }
1547 
1548 //------------------------------------------------------------------------------
1549 // PySymbolTable.
1550 //------------------------------------------------------------------------------
1551 
1552 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1553     : operation(operation.getOperation().getRef()) {
1554   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1555   if (mlirSymbolTableIsNull(symbolTable)) {
1556     throw py::cast_error("Operation is not a Symbol Table.");
1557   }
1558 }
1559 
1560 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1561   operation->checkValid();
1562   MlirOperation symbol = mlirSymbolTableLookup(
1563       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1564   if (mlirOperationIsNull(symbol))
1565     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1566 
1567   return PyOperation::forOperation(operation->getContext(), symbol,
1568                                    operation.getObject())
1569       ->createOpView();
1570 }
1571 
1572 void PySymbolTable::erase(PyOperationBase &symbol) {
1573   operation->checkValid();
1574   symbol.getOperation().checkValid();
1575   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1576   // The operation is also erased, so we must invalidate it. There may be Python
1577   // references to this operation so we don't want to delete it from the list of
1578   // live operations here.
1579   symbol.getOperation().valid = false;
1580 }
1581 
1582 void PySymbolTable::dunderDel(const std::string &name) {
1583   py::object operation = dunderGetItem(name);
1584   erase(py::cast<PyOperationBase &>(operation));
1585 }
1586 
1587 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1588   operation->checkValid();
1589   symbol.getOperation().checkValid();
1590   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1591       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1592   if (mlirAttributeIsNull(symbolAttr))
1593     throw py::value_error("Expected operation to have a symbol name.");
1594   return PyAttribute(
1595       symbol.getOperation().getContext(),
1596       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1597 }
1598 
1599 namespace {
1600 /// CRTP base class for Python MLIR values that subclass Value and should be
1601 /// castable from it. The value hierarchy is one level deep and is not supposed
1602 /// to accommodate other levels unless core MLIR changes.
1603 template <typename DerivedTy>
1604 class PyConcreteValue : public PyValue {
1605 public:
1606   // Derived classes must define statics for:
1607   //   IsAFunctionTy isaFunction
1608   //   const char *pyClassName
1609   // and redefine bindDerived.
1610   using ClassTy = py::class_<DerivedTy, PyValue>;
1611   using IsAFunctionTy = bool (*)(MlirValue);
1612 
1613   PyConcreteValue() = default;
1614   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1615       : PyValue(operationRef, value) {}
1616   PyConcreteValue(PyValue &orig)
1617       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1618 
1619   /// Attempts to cast the original value to the derived type and throws on
1620   /// type mismatches.
1621   static MlirValue castFrom(PyValue &orig) {
1622     if (!DerivedTy::isaFunction(orig.get())) {
1623       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1624       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1625                                              DerivedTy::pyClassName +
1626                                              " (from " + origRepr + ")");
1627     }
1628     return orig.get();
1629   }
1630 
1631   /// Binds the Python module objects to functions of this class.
1632   static void bind(py::module &m) {
1633     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1634     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1635     cls.def_static("isinstance", [](PyValue &otherValue) -> bool {
1636       return DerivedTy::isaFunction(otherValue);
1637     });
1638     DerivedTy::bindDerived(cls);
1639   }
1640 
1641   /// Implemented by derived classes to add methods to the Python subclass.
1642   static void bindDerived(ClassTy &m) {}
1643 };
1644 
1645 /// Python wrapper for MlirBlockArgument.
1646 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1647 public:
1648   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1649   static constexpr const char *pyClassName = "BlockArgument";
1650   using PyConcreteValue::PyConcreteValue;
1651 
1652   static void bindDerived(ClassTy &c) {
1653     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1654       return PyBlock(self.getParentOperation(),
1655                      mlirBlockArgumentGetOwner(self.get()));
1656     });
1657     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1658       return mlirBlockArgumentGetArgNumber(self.get());
1659     });
1660     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1661       return mlirBlockArgumentSetType(self.get(), type);
1662     });
1663   }
1664 };
1665 
1666 /// Python wrapper for MlirOpResult.
1667 class PyOpResult : public PyConcreteValue<PyOpResult> {
1668 public:
1669   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1670   static constexpr const char *pyClassName = "OpResult";
1671   using PyConcreteValue::PyConcreteValue;
1672 
1673   static void bindDerived(ClassTy &c) {
1674     c.def_property_readonly("owner", [](PyOpResult &self) {
1675       assert(
1676           mlirOperationEqual(self.getParentOperation()->get(),
1677                              mlirOpResultGetOwner(self.get())) &&
1678           "expected the owner of the value in Python to match that in the IR");
1679       return self.getParentOperation().getObject();
1680     });
1681     c.def_property_readonly("result_number", [](PyOpResult &self) {
1682       return mlirOpResultGetResultNumber(self.get());
1683     });
1684   }
1685 };
1686 
1687 /// Returns the list of types of the values held by container.
1688 template <typename Container>
1689 static std::vector<PyType> getValueTypes(Container &container,
1690                                          PyMlirContextRef &context) {
1691   std::vector<PyType> result;
1692   result.reserve(container.getNumElements());
1693   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1694     result.push_back(
1695         PyType(context, mlirValueGetType(container.getElement(i).get())));
1696   }
1697   return result;
1698 }
1699 
1700 /// A list of block arguments. Internally, these are stored as consecutive
1701 /// elements, random access is cheap. The argument list is associated with the
1702 /// operation that contains the block (detached blocks are not allowed in
1703 /// Python bindings) and extends its lifetime.
1704 class PyBlockArgumentList
1705     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1706 public:
1707   static constexpr const char *pyClassName = "BlockArgumentList";
1708 
1709   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1710                       intptr_t startIndex = 0, intptr_t length = -1,
1711                       intptr_t step = 1)
1712       : Sliceable(startIndex,
1713                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1714                   step),
1715         operation(std::move(operation)), block(block) {}
1716 
1717   /// Returns the number of arguments in the list.
1718   intptr_t getNumElements() {
1719     operation->checkValid();
1720     return mlirBlockGetNumArguments(block);
1721   }
1722 
1723   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1724   PyBlockArgument getElement(intptr_t pos) {
1725     MlirValue argument = mlirBlockGetArgument(block, pos);
1726     return PyBlockArgument(operation, argument);
1727   }
1728 
1729   /// Returns a sublist of this list.
1730   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1731                             intptr_t step) {
1732     return PyBlockArgumentList(operation, block, startIndex, length, step);
1733   }
1734 
1735   static void bindDerived(ClassTy &c) {
1736     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1737       return getValueTypes(self, self.operation->getContext());
1738     });
1739   }
1740 
1741 private:
1742   PyOperationRef operation;
1743   MlirBlock block;
1744 };
1745 
1746 /// A list of operation operands. Internally, these are stored as consecutive
1747 /// elements, random access is cheap. The result list is associated with the
1748 /// operation whose results these are, and extends the lifetime of this
1749 /// operation.
1750 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1751 public:
1752   static constexpr const char *pyClassName = "OpOperandList";
1753 
1754   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1755                   intptr_t length = -1, intptr_t step = 1)
1756       : Sliceable(startIndex,
1757                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1758                                : length,
1759                   step),
1760         operation(operation) {}
1761 
1762   intptr_t getNumElements() {
1763     operation->checkValid();
1764     return mlirOperationGetNumOperands(operation->get());
1765   }
1766 
1767   PyValue getElement(intptr_t pos) {
1768     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1769     MlirOperation owner;
1770     if (mlirValueIsAOpResult(operand))
1771       owner = mlirOpResultGetOwner(operand);
1772     else if (mlirValueIsABlockArgument(operand))
1773       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1774     else
1775       assert(false && "Value must be an block arg or op result.");
1776     PyOperationRef pyOwner =
1777         PyOperation::forOperation(operation->getContext(), owner);
1778     return PyValue(pyOwner, operand);
1779   }
1780 
1781   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1782     return PyOpOperandList(operation, startIndex, length, step);
1783   }
1784 
1785   void dunderSetItem(intptr_t index, PyValue value) {
1786     index = wrapIndex(index);
1787     mlirOperationSetOperand(operation->get(), index, value.get());
1788   }
1789 
1790   static void bindDerived(ClassTy &c) {
1791     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1792   }
1793 
1794 private:
1795   PyOperationRef operation;
1796 };
1797 
1798 /// A list of operation results. Internally, these are stored as consecutive
1799 /// elements, random access is cheap. The result list is associated with the
1800 /// operation whose results these are, and extends the lifetime of this
1801 /// operation.
1802 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1803 public:
1804   static constexpr const char *pyClassName = "OpResultList";
1805 
1806   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1807                  intptr_t length = -1, intptr_t step = 1)
1808       : Sliceable(startIndex,
1809                   length == -1 ? mlirOperationGetNumResults(operation->get())
1810                                : length,
1811                   step),
1812         operation(operation) {}
1813 
1814   intptr_t getNumElements() {
1815     operation->checkValid();
1816     return mlirOperationGetNumResults(operation->get());
1817   }
1818 
1819   PyOpResult getElement(intptr_t index) {
1820     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1821     return PyOpResult(value);
1822   }
1823 
1824   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1825     return PyOpResultList(operation, startIndex, length, step);
1826   }
1827 
1828   static void bindDerived(ClassTy &c) {
1829     c.def_property_readonly("types", [](PyOpResultList &self) {
1830       return getValueTypes(self, self.operation->getContext());
1831     });
1832   }
1833 
1834 private:
1835   PyOperationRef operation;
1836 };
1837 
1838 /// A list of operation attributes. Can be indexed by name, producing
1839 /// attributes, or by index, producing named attributes.
1840 class PyOpAttributeMap {
1841 public:
1842   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1843 
1844   PyAttribute dunderGetItemNamed(const std::string &name) {
1845     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1846                                                          toMlirStringRef(name));
1847     if (mlirAttributeIsNull(attr)) {
1848       throw SetPyError(PyExc_KeyError,
1849                        "attempt to access a non-existent attribute");
1850     }
1851     return PyAttribute(operation->getContext(), attr);
1852   }
1853 
1854   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1855     if (index < 0 || index >= dunderLen()) {
1856       throw SetPyError(PyExc_IndexError,
1857                        "attempt to access out of bounds attribute");
1858     }
1859     MlirNamedAttribute namedAttr =
1860         mlirOperationGetAttribute(operation->get(), index);
1861     return PyNamedAttribute(
1862         namedAttr.attribute,
1863         std::string(mlirIdentifierStr(namedAttr.name).data,
1864                     mlirIdentifierStr(namedAttr.name).length));
1865   }
1866 
1867   void dunderSetItem(const std::string &name, PyAttribute attr) {
1868     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1869                                     attr);
1870   }
1871 
1872   void dunderDelItem(const std::string &name) {
1873     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1874                                                      toMlirStringRef(name));
1875     if (!removed)
1876       throw SetPyError(PyExc_KeyError,
1877                        "attempt to delete a non-existent attribute");
1878   }
1879 
1880   intptr_t dunderLen() {
1881     return mlirOperationGetNumAttributes(operation->get());
1882   }
1883 
1884   bool dunderContains(const std::string &name) {
1885     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1886         operation->get(), toMlirStringRef(name)));
1887   }
1888 
1889   static void bind(py::module &m) {
1890     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
1891         .def("__contains__", &PyOpAttributeMap::dunderContains)
1892         .def("__len__", &PyOpAttributeMap::dunderLen)
1893         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1894         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1895         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1896         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1897   }
1898 
1899 private:
1900   PyOperationRef operation;
1901 };
1902 
1903 } // end namespace
1904 
1905 //------------------------------------------------------------------------------
1906 // Populates the core exports of the 'ir' submodule.
1907 //------------------------------------------------------------------------------
1908 
1909 void mlir::python::populateIRCore(py::module &m) {
1910   //----------------------------------------------------------------------------
1911   // Mapping of MlirContext.
1912   //----------------------------------------------------------------------------
1913   py::class_<PyMlirContext>(m, "Context", py::module_local())
1914       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1915       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1916       .def("_get_context_again",
1917            [](PyMlirContext &self) {
1918              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1919              return ref.releaseObject();
1920            })
1921       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1922       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1923       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1924                              &PyMlirContext::getCapsule)
1925       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1926       .def("__enter__", &PyMlirContext::contextEnter)
1927       .def("__exit__", &PyMlirContext::contextExit)
1928       .def_property_readonly_static(
1929           "current",
1930           [](py::object & /*class*/) {
1931             auto *context = PyThreadContextEntry::getDefaultContext();
1932             if (!context)
1933               throw SetPyError(PyExc_ValueError, "No current Context");
1934             return context;
1935           },
1936           "Gets the Context bound to the current thread or raises ValueError")
1937       .def_property_readonly(
1938           "dialects",
1939           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1940           "Gets a container for accessing dialects by name")
1941       .def_property_readonly(
1942           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1943           "Alias for 'dialect'")
1944       .def(
1945           "get_dialect_descriptor",
1946           [=](PyMlirContext &self, std::string &name) {
1947             MlirDialect dialect = mlirContextGetOrLoadDialect(
1948                 self.get(), {name.data(), name.size()});
1949             if (mlirDialectIsNull(dialect)) {
1950               throw SetPyError(PyExc_ValueError,
1951                                Twine("Dialect '") + name + "' not found");
1952             }
1953             return PyDialectDescriptor(self.getRef(), dialect);
1954           },
1955           "Gets or loads a dialect by name, returning its descriptor object")
1956       .def_property(
1957           "allow_unregistered_dialects",
1958           [](PyMlirContext &self) -> bool {
1959             return mlirContextGetAllowUnregisteredDialects(self.get());
1960           },
1961           [](PyMlirContext &self, bool value) {
1962             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1963           })
1964       .def("enable_multithreading",
1965            [](PyMlirContext &self, bool enable) {
1966              mlirContextEnableMultithreading(self.get(), enable);
1967            })
1968       .def("is_registered_operation",
1969            [](PyMlirContext &self, std::string &name) {
1970              return mlirContextIsRegisteredOperation(
1971                  self.get(), MlirStringRef{name.data(), name.size()});
1972            });
1973 
1974   //----------------------------------------------------------------------------
1975   // Mapping of PyDialectDescriptor
1976   //----------------------------------------------------------------------------
1977   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
1978       .def_property_readonly("namespace",
1979                              [](PyDialectDescriptor &self) {
1980                                MlirStringRef ns =
1981                                    mlirDialectGetNamespace(self.get());
1982                                return py::str(ns.data, ns.length);
1983                              })
1984       .def("__repr__", [](PyDialectDescriptor &self) {
1985         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1986         std::string repr("<DialectDescriptor ");
1987         repr.append(ns.data, ns.length);
1988         repr.append(">");
1989         return repr;
1990       });
1991 
1992   //----------------------------------------------------------------------------
1993   // Mapping of PyDialects
1994   //----------------------------------------------------------------------------
1995   py::class_<PyDialects>(m, "Dialects", py::module_local())
1996       .def("__getitem__",
1997            [=](PyDialects &self, std::string keyName) {
1998              MlirDialect dialect =
1999                  self.getDialectForKey(keyName, /*attrError=*/false);
2000              py::object descriptor =
2001                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2002              return createCustomDialectWrapper(keyName, std::move(descriptor));
2003            })
2004       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2005         MlirDialect dialect =
2006             self.getDialectForKey(attrName, /*attrError=*/true);
2007         py::object descriptor =
2008             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2009         return createCustomDialectWrapper(attrName, std::move(descriptor));
2010       });
2011 
2012   //----------------------------------------------------------------------------
2013   // Mapping of PyDialect
2014   //----------------------------------------------------------------------------
2015   py::class_<PyDialect>(m, "Dialect", py::module_local())
2016       .def(py::init<py::object>(), "descriptor")
2017       .def_property_readonly(
2018           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2019       .def("__repr__", [](py::object self) {
2020         auto clazz = self.attr("__class__");
2021         return py::str("<Dialect ") +
2022                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2023                clazz.attr("__module__") + py::str(".") +
2024                clazz.attr("__name__") + py::str(")>");
2025       });
2026 
2027   //----------------------------------------------------------------------------
2028   // Mapping of Location
2029   //----------------------------------------------------------------------------
2030   py::class_<PyLocation>(m, "Location", py::module_local())
2031       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2032       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2033       .def("__enter__", &PyLocation::contextEnter)
2034       .def("__exit__", &PyLocation::contextExit)
2035       .def("__eq__",
2036            [](PyLocation &self, PyLocation &other) -> bool {
2037              return mlirLocationEqual(self, other);
2038            })
2039       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2040       .def_property_readonly_static(
2041           "current",
2042           [](py::object & /*class*/) {
2043             auto *loc = PyThreadContextEntry::getDefaultLocation();
2044             if (!loc)
2045               throw SetPyError(PyExc_ValueError, "No current Location");
2046             return loc;
2047           },
2048           "Gets the Location bound to the current thread or raises ValueError")
2049       .def_static(
2050           "unknown",
2051           [](DefaultingPyMlirContext context) {
2052             return PyLocation(context->getRef(),
2053                               mlirLocationUnknownGet(context->get()));
2054           },
2055           py::arg("context") = py::none(),
2056           "Gets a Location representing an unknown location")
2057       .def_static(
2058           "callsite",
2059           [](PyLocation callee, const std::vector<PyLocation> &frames,
2060              DefaultingPyMlirContext context) {
2061             if (frames.empty())
2062               throw py::value_error("No caller frames provided");
2063             MlirLocation caller = frames.back().get();
2064             for (const PyLocation &frame :
2065                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2066               caller = mlirLocationCallSiteGet(frame.get(), caller);
2067             return PyLocation(context->getRef(),
2068                               mlirLocationCallSiteGet(callee.get(), caller));
2069           },
2070           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2071           kContextGetCallSiteLocationDocstring)
2072       .def_static(
2073           "file",
2074           [](std::string filename, int line, int col,
2075              DefaultingPyMlirContext context) {
2076             return PyLocation(
2077                 context->getRef(),
2078                 mlirLocationFileLineColGet(
2079                     context->get(), toMlirStringRef(filename), line, col));
2080           },
2081           py::arg("filename"), py::arg("line"), py::arg("col"),
2082           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2083       .def_static(
2084           "name",
2085           [](std::string name, llvm::Optional<PyLocation> childLoc,
2086              DefaultingPyMlirContext context) {
2087             return PyLocation(
2088                 context->getRef(),
2089                 mlirLocationNameGet(
2090                     context->get(), toMlirStringRef(name),
2091                     childLoc ? childLoc->get()
2092                              : mlirLocationUnknownGet(context->get())));
2093           },
2094           py::arg("name"), py::arg("childLoc") = py::none(),
2095           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2096       .def_property_readonly(
2097           "context",
2098           [](PyLocation &self) { return self.getContext().getObject(); },
2099           "Context that owns the Location")
2100       .def("__repr__", [](PyLocation &self) {
2101         PyPrintAccumulator printAccum;
2102         mlirLocationPrint(self, printAccum.getCallback(),
2103                           printAccum.getUserData());
2104         return printAccum.join();
2105       });
2106 
2107   //----------------------------------------------------------------------------
2108   // Mapping of Module
2109   //----------------------------------------------------------------------------
2110   py::class_<PyModule>(m, "Module", py::module_local())
2111       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2112       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2113       .def_static(
2114           "parse",
2115           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2116             MlirModule module = mlirModuleCreateParse(
2117                 context->get(), toMlirStringRef(moduleAsm));
2118             // TODO: Rework error reporting once diagnostic engine is exposed
2119             // in C API.
2120             if (mlirModuleIsNull(module)) {
2121               throw SetPyError(
2122                   PyExc_ValueError,
2123                   "Unable to parse module assembly (see diagnostics)");
2124             }
2125             return PyModule::forModule(module).releaseObject();
2126           },
2127           py::arg("asm"), py::arg("context") = py::none(),
2128           kModuleParseDocstring)
2129       .def_static(
2130           "create",
2131           [](DefaultingPyLocation loc) {
2132             MlirModule module = mlirModuleCreateEmpty(loc);
2133             return PyModule::forModule(module).releaseObject();
2134           },
2135           py::arg("loc") = py::none(), "Creates an empty module")
2136       .def_property_readonly(
2137           "context",
2138           [](PyModule &self) { return self.getContext().getObject(); },
2139           "Context that created the Module")
2140       .def_property_readonly(
2141           "operation",
2142           [](PyModule &self) {
2143             return PyOperation::forOperation(self.getContext(),
2144                                              mlirModuleGetOperation(self.get()),
2145                                              self.getRef().releaseObject())
2146                 .releaseObject();
2147           },
2148           "Accesses the module as an operation")
2149       .def_property_readonly(
2150           "body",
2151           [](PyModule &self) {
2152             PyOperationRef module_op = PyOperation::forOperation(
2153                 self.getContext(), mlirModuleGetOperation(self.get()),
2154                 self.getRef().releaseObject());
2155             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2156             return returnBlock;
2157           },
2158           "Return the block for this module")
2159       .def(
2160           "dump",
2161           [](PyModule &self) {
2162             mlirOperationDump(mlirModuleGetOperation(self.get()));
2163           },
2164           kDumpDocstring)
2165       .def(
2166           "__str__",
2167           [](py::object self) {
2168             // Defer to the operation's __str__.
2169             return self.attr("operation").attr("__str__")();
2170           },
2171           kOperationStrDunderDocstring);
2172 
2173   //----------------------------------------------------------------------------
2174   // Mapping of Operation.
2175   //----------------------------------------------------------------------------
2176   py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
2177       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2178                              [](PyOperationBase &self) {
2179                                return self.getOperation().getCapsule();
2180                              })
2181       .def("__eq__",
2182            [](PyOperationBase &self, PyOperationBase &other) {
2183              return &self.getOperation() == &other.getOperation();
2184            })
2185       .def("__eq__",
2186            [](PyOperationBase &self, py::object other) { return false; })
2187       .def("__hash__",
2188            [](PyOperationBase &self) {
2189              return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
2190            })
2191       .def_property_readonly("attributes",
2192                              [](PyOperationBase &self) {
2193                                return PyOpAttributeMap(
2194                                    self.getOperation().getRef());
2195                              })
2196       .def_property_readonly("operands",
2197                              [](PyOperationBase &self) {
2198                                return PyOpOperandList(
2199                                    self.getOperation().getRef());
2200                              })
2201       .def_property_readonly("regions",
2202                              [](PyOperationBase &self) {
2203                                return PyRegionList(
2204                                    self.getOperation().getRef());
2205                              })
2206       .def_property_readonly(
2207           "results",
2208           [](PyOperationBase &self) {
2209             return PyOpResultList(self.getOperation().getRef());
2210           },
2211           "Returns the list of Operation results.")
2212       .def_property_readonly(
2213           "result",
2214           [](PyOperationBase &self) {
2215             auto &operation = self.getOperation();
2216             auto numResults = mlirOperationGetNumResults(operation);
2217             if (numResults != 1) {
2218               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2219               throw SetPyError(
2220                   PyExc_ValueError,
2221                   Twine("Cannot call .result on operation ") +
2222                       StringRef(name.data, name.length) + " which has " +
2223                       Twine(numResults) +
2224                       " results (it is only valid for operations with a "
2225                       "single result)");
2226             }
2227             return PyOpResult(operation.getRef(),
2228                               mlirOperationGetResult(operation, 0));
2229           },
2230           "Shortcut to get an op result if it has only one (throws an error "
2231           "otherwise).")
2232       .def_property_readonly(
2233           "location",
2234           [](PyOperationBase &self) {
2235             PyOperation &operation = self.getOperation();
2236             return PyLocation(operation.getContext(),
2237                               mlirOperationGetLocation(operation.get()));
2238           },
2239           "Returns the source location the operation was defined or derived "
2240           "from.")
2241       .def(
2242           "__str__",
2243           [](PyOperationBase &self) {
2244             return self.getAsm(/*binary=*/false,
2245                                /*largeElementsLimit=*/llvm::None,
2246                                /*enableDebugInfo=*/false,
2247                                /*prettyDebugInfo=*/false,
2248                                /*printGenericOpForm=*/false,
2249                                /*useLocalScope=*/false,
2250                                /*assumeVerified=*/false);
2251           },
2252           "Returns the assembly form of the operation.")
2253       .def("print", &PyOperationBase::print,
2254            // Careful: Lots of arguments must match up with print method.
2255            py::arg("file") = py::none(), py::arg("binary") = false,
2256            py::arg("large_elements_limit") = py::none(),
2257            py::arg("enable_debug_info") = false,
2258            py::arg("pretty_debug_info") = false,
2259            py::arg("print_generic_op_form") = false,
2260            py::arg("use_local_scope") = false,
2261            py::arg("assume_verified") = false, kOperationPrintDocstring)
2262       .def("get_asm", &PyOperationBase::getAsm,
2263            // Careful: Lots of arguments must match up with get_asm method.
2264            py::arg("binary") = false,
2265            py::arg("large_elements_limit") = py::none(),
2266            py::arg("enable_debug_info") = false,
2267            py::arg("pretty_debug_info") = false,
2268            py::arg("print_generic_op_form") = false,
2269            py::arg("use_local_scope") = false,
2270            py::arg("assume_verified") = false, kOperationGetAsmDocstring)
2271       .def(
2272           "verify",
2273           [](PyOperationBase &self) {
2274             return mlirOperationVerify(self.getOperation());
2275           },
2276           "Verify the operation and return true if it passes, false if it "
2277           "fails.")
2278       .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
2279            "Puts self immediately after the other operation in its parent "
2280            "block.")
2281       .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
2282            "Puts self immediately before the other operation in its parent "
2283            "block.")
2284       .def(
2285           "detach_from_parent",
2286           [](PyOperationBase &self) {
2287             PyOperation &operation = self.getOperation();
2288             operation.checkValid();
2289             if (!operation.isAttached())
2290               throw py::value_error("Detached operation has no parent.");
2291 
2292             operation.detachFromParent();
2293             return operation.createOpView();
2294           },
2295           "Detaches the operation from its parent block.");
2296 
2297   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2298       .def_static("create", &PyOperation::create, py::arg("name"),
2299                   py::arg("results") = py::none(),
2300                   py::arg("operands") = py::none(),
2301                   py::arg("attributes") = py::none(),
2302                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2303                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2304                   kOperationCreateDocstring)
2305       .def_property_readonly("parent",
2306                              [](PyOperation &self) -> py::object {
2307                                auto parent = self.getParentOperation();
2308                                if (parent)
2309                                  return parent->getObject();
2310                                return py::none();
2311                              })
2312       .def("erase", &PyOperation::erase)
2313       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2314                              &PyOperation::getCapsule)
2315       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2316       .def_property_readonly("name",
2317                              [](PyOperation &self) {
2318                                self.checkValid();
2319                                MlirOperation operation = self.get();
2320                                MlirStringRef name = mlirIdentifierStr(
2321                                    mlirOperationGetName(operation));
2322                                return py::str(name.data, name.length);
2323                              })
2324       .def_property_readonly(
2325           "context",
2326           [](PyOperation &self) {
2327             self.checkValid();
2328             return self.getContext().getObject();
2329           },
2330           "Context that owns the Operation")
2331       .def_property_readonly("opview", &PyOperation::createOpView);
2332 
2333   auto opViewClass =
2334       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2335           .def(py::init<py::object>())
2336           .def_property_readonly("operation", &PyOpView::getOperationObject)
2337           .def_property_readonly(
2338               "context",
2339               [](PyOpView &self) {
2340                 return self.getOperation().getContext().getObject();
2341               },
2342               "Context that owns the Operation")
2343           .def("__str__", [](PyOpView &self) {
2344             return py::str(self.getOperationObject());
2345           });
2346   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2347   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2348   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2349   opViewClass.attr("build_generic") = classmethod(
2350       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2351       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2352       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2353       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2354       "Builds a specific, generated OpView based on class level attributes.");
2355 
2356   //----------------------------------------------------------------------------
2357   // Mapping of PyRegion.
2358   //----------------------------------------------------------------------------
2359   py::class_<PyRegion>(m, "Region", py::module_local())
2360       .def_property_readonly(
2361           "blocks",
2362           [](PyRegion &self) {
2363             return PyBlockList(self.getParentOperation(), self.get());
2364           },
2365           "Returns a forward-optimized sequence of blocks.")
2366       .def_property_readonly(
2367           "owner",
2368           [](PyRegion &self) {
2369             return self.getParentOperation()->createOpView();
2370           },
2371           "Returns the operation owning this region.")
2372       .def(
2373           "__iter__",
2374           [](PyRegion &self) {
2375             self.checkValid();
2376             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2377             return PyBlockIterator(self.getParentOperation(), firstBlock);
2378           },
2379           "Iterates over blocks in the region.")
2380       .def("__eq__",
2381            [](PyRegion &self, PyRegion &other) {
2382              return self.get().ptr == other.get().ptr;
2383            })
2384       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2385 
2386   //----------------------------------------------------------------------------
2387   // Mapping of PyBlock.
2388   //----------------------------------------------------------------------------
2389   py::class_<PyBlock>(m, "Block", py::module_local())
2390       .def_property_readonly(
2391           "owner",
2392           [](PyBlock &self) {
2393             return self.getParentOperation()->createOpView();
2394           },
2395           "Returns the owning operation of this block.")
2396       .def_property_readonly(
2397           "region",
2398           [](PyBlock &self) {
2399             MlirRegion region = mlirBlockGetParentRegion(self.get());
2400             return PyRegion(self.getParentOperation(), region);
2401           },
2402           "Returns the owning region of this block.")
2403       .def_property_readonly(
2404           "arguments",
2405           [](PyBlock &self) {
2406             return PyBlockArgumentList(self.getParentOperation(), self.get());
2407           },
2408           "Returns a list of block arguments.")
2409       .def_property_readonly(
2410           "operations",
2411           [](PyBlock &self) {
2412             return PyOperationList(self.getParentOperation(), self.get());
2413           },
2414           "Returns a forward-optimized sequence of operations.")
2415       .def_static(
2416           "create_at_start",
2417           [](PyRegion &parent, py::list pyArgTypes) {
2418             parent.checkValid();
2419             llvm::SmallVector<MlirType, 4> argTypes;
2420             argTypes.reserve(pyArgTypes.size());
2421             for (auto &pyArg : pyArgTypes) {
2422               argTypes.push_back(pyArg.cast<PyType &>());
2423             }
2424 
2425             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2426             mlirRegionInsertOwnedBlock(parent, 0, block);
2427             return PyBlock(parent.getParentOperation(), block);
2428           },
2429           py::arg("parent"), py::arg("pyArgTypes") = py::list(),
2430           "Creates and returns a new Block at the beginning of the given "
2431           "region (with given argument types).")
2432       .def(
2433           "create_before",
2434           [](PyBlock &self, py::args pyArgTypes) {
2435             self.checkValid();
2436             llvm::SmallVector<MlirType, 4> argTypes;
2437             argTypes.reserve(pyArgTypes.size());
2438             for (auto &pyArg : pyArgTypes) {
2439               argTypes.push_back(pyArg.cast<PyType &>());
2440             }
2441 
2442             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2443             MlirRegion region = mlirBlockGetParentRegion(self.get());
2444             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2445             return PyBlock(self.getParentOperation(), block);
2446           },
2447           "Creates and returns a new Block before this block "
2448           "(with given argument types).")
2449       .def(
2450           "create_after",
2451           [](PyBlock &self, py::args pyArgTypes) {
2452             self.checkValid();
2453             llvm::SmallVector<MlirType, 4> argTypes;
2454             argTypes.reserve(pyArgTypes.size());
2455             for (auto &pyArg : pyArgTypes) {
2456               argTypes.push_back(pyArg.cast<PyType &>());
2457             }
2458 
2459             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2460             MlirRegion region = mlirBlockGetParentRegion(self.get());
2461             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2462             return PyBlock(self.getParentOperation(), block);
2463           },
2464           "Creates and returns a new Block after this block "
2465           "(with given argument types).")
2466       .def(
2467           "__iter__",
2468           [](PyBlock &self) {
2469             self.checkValid();
2470             MlirOperation firstOperation =
2471                 mlirBlockGetFirstOperation(self.get());
2472             return PyOperationIterator(self.getParentOperation(),
2473                                        firstOperation);
2474           },
2475           "Iterates over operations in the block.")
2476       .def("__eq__",
2477            [](PyBlock &self, PyBlock &other) {
2478              return self.get().ptr == other.get().ptr;
2479            })
2480       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2481       .def(
2482           "__str__",
2483           [](PyBlock &self) {
2484             self.checkValid();
2485             PyPrintAccumulator printAccum;
2486             mlirBlockPrint(self.get(), printAccum.getCallback(),
2487                            printAccum.getUserData());
2488             return printAccum.join();
2489           },
2490           "Returns the assembly form of the block.")
2491       .def(
2492           "append",
2493           [](PyBlock &self, PyOperationBase &operation) {
2494             if (operation.getOperation().isAttached())
2495               operation.getOperation().detachFromParent();
2496 
2497             MlirOperation mlirOperation = operation.getOperation().get();
2498             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2499             operation.getOperation().setAttached(
2500                 self.getParentOperation().getObject());
2501           },
2502           "Appends an operation to this block. If the operation is currently "
2503           "in another block, it will be moved.");
2504 
2505   //----------------------------------------------------------------------------
2506   // Mapping of PyInsertionPoint.
2507   //----------------------------------------------------------------------------
2508 
2509   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2510       .def(py::init<PyBlock &>(), py::arg("block"),
2511            "Inserts after the last operation but still inside the block.")
2512       .def("__enter__", &PyInsertionPoint::contextEnter)
2513       .def("__exit__", &PyInsertionPoint::contextExit)
2514       .def_property_readonly_static(
2515           "current",
2516           [](py::object & /*class*/) {
2517             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2518             if (!ip)
2519               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2520             return ip;
2521           },
2522           "Gets the InsertionPoint bound to the current thread or raises "
2523           "ValueError if none has been set")
2524       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2525            "Inserts before a referenced operation.")
2526       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2527                   py::arg("block"), "Inserts at the beginning of the block.")
2528       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2529                   py::arg("block"), "Inserts before the block terminator.")
2530       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2531            "Inserts an operation.")
2532       .def_property_readonly(
2533           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2534           "Returns the block that this InsertionPoint points to.");
2535 
2536   //----------------------------------------------------------------------------
2537   // Mapping of PyAttribute.
2538   //----------------------------------------------------------------------------
2539   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2540       // Delegate to the PyAttribute copy constructor, which will also lifetime
2541       // extend the backing context which owns the MlirAttribute.
2542       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2543            "Casts the passed attribute to the generic Attribute")
2544       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2545                              &PyAttribute::getCapsule)
2546       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2547       .def_static(
2548           "parse",
2549           [](std::string attrSpec, DefaultingPyMlirContext context) {
2550             MlirAttribute type = mlirAttributeParseGet(
2551                 context->get(), toMlirStringRef(attrSpec));
2552             // TODO: Rework error reporting once diagnostic engine is exposed
2553             // in C API.
2554             if (mlirAttributeIsNull(type)) {
2555               throw SetPyError(PyExc_ValueError,
2556                                Twine("Unable to parse attribute: '") +
2557                                    attrSpec + "'");
2558             }
2559             return PyAttribute(context->getRef(), type);
2560           },
2561           py::arg("asm"), py::arg("context") = py::none(),
2562           "Parses an attribute from an assembly form")
2563       .def_property_readonly(
2564           "context",
2565           [](PyAttribute &self) { return self.getContext().getObject(); },
2566           "Context that owns the Attribute")
2567       .def_property_readonly("type",
2568                              [](PyAttribute &self) {
2569                                return PyType(self.getContext()->getRef(),
2570                                              mlirAttributeGetType(self));
2571                              })
2572       .def(
2573           "get_named",
2574           [](PyAttribute &self, std::string name) {
2575             return PyNamedAttribute(self, std::move(name));
2576           },
2577           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2578       .def("__eq__",
2579            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2580       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2581       .def("__hash__",
2582            [](PyAttribute &self) {
2583              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2584            })
2585       .def(
2586           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2587           kDumpDocstring)
2588       .def(
2589           "__str__",
2590           [](PyAttribute &self) {
2591             PyPrintAccumulator printAccum;
2592             mlirAttributePrint(self, printAccum.getCallback(),
2593                                printAccum.getUserData());
2594             return printAccum.join();
2595           },
2596           "Returns the assembly form of the Attribute.")
2597       .def("__repr__", [](PyAttribute &self) {
2598         // Generally, assembly formats are not printed for __repr__ because
2599         // this can cause exceptionally long debug output and exceptions.
2600         // However, attribute values are generally considered useful and are
2601         // printed. This may need to be re-evaluated if debug dumps end up
2602         // being excessive.
2603         PyPrintAccumulator printAccum;
2604         printAccum.parts.append("Attribute(");
2605         mlirAttributePrint(self, printAccum.getCallback(),
2606                            printAccum.getUserData());
2607         printAccum.parts.append(")");
2608         return printAccum.join();
2609       });
2610 
2611   //----------------------------------------------------------------------------
2612   // Mapping of PyNamedAttribute
2613   //----------------------------------------------------------------------------
2614   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2615       .def("__repr__",
2616            [](PyNamedAttribute &self) {
2617              PyPrintAccumulator printAccum;
2618              printAccum.parts.append("NamedAttribute(");
2619              printAccum.parts.append(
2620                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
2621                          mlirIdentifierStr(self.namedAttr.name).length));
2622              printAccum.parts.append("=");
2623              mlirAttributePrint(self.namedAttr.attribute,
2624                                 printAccum.getCallback(),
2625                                 printAccum.getUserData());
2626              printAccum.parts.append(")");
2627              return printAccum.join();
2628            })
2629       .def_property_readonly(
2630           "name",
2631           [](PyNamedAttribute &self) {
2632             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2633                            mlirIdentifierStr(self.namedAttr.name).length);
2634           },
2635           "The name of the NamedAttribute binding")
2636       .def_property_readonly(
2637           "attr",
2638           [](PyNamedAttribute &self) {
2639             // TODO: When named attribute is removed/refactored, also remove
2640             // this constructor (it does an inefficient table lookup).
2641             auto contextRef = PyMlirContext::forContext(
2642                 mlirAttributeGetContext(self.namedAttr.attribute));
2643             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2644           },
2645           py::keep_alive<0, 1>(),
2646           "The underlying generic attribute of the NamedAttribute binding");
2647 
2648   //----------------------------------------------------------------------------
2649   // Mapping of PyType.
2650   //----------------------------------------------------------------------------
2651   py::class_<PyType>(m, "Type", py::module_local())
2652       // Delegate to the PyType copy constructor, which will also lifetime
2653       // extend the backing context which owns the MlirType.
2654       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2655            "Casts the passed type to the generic Type")
2656       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2657       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2658       .def_static(
2659           "parse",
2660           [](std::string typeSpec, DefaultingPyMlirContext context) {
2661             MlirType type =
2662                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2663             // TODO: Rework error reporting once diagnostic engine is exposed
2664             // in C API.
2665             if (mlirTypeIsNull(type)) {
2666               throw SetPyError(PyExc_ValueError,
2667                                Twine("Unable to parse type: '") + typeSpec +
2668                                    "'");
2669             }
2670             return PyType(context->getRef(), type);
2671           },
2672           py::arg("asm"), py::arg("context") = py::none(),
2673           kContextParseTypeDocstring)
2674       .def_property_readonly(
2675           "context", [](PyType &self) { return self.getContext().getObject(); },
2676           "Context that owns the Type")
2677       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2678       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2679       .def("__hash__",
2680            [](PyType &self) {
2681              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2682            })
2683       .def(
2684           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2685       .def(
2686           "__str__",
2687           [](PyType &self) {
2688             PyPrintAccumulator printAccum;
2689             mlirTypePrint(self, printAccum.getCallback(),
2690                           printAccum.getUserData());
2691             return printAccum.join();
2692           },
2693           "Returns the assembly form of the type.")
2694       .def("__repr__", [](PyType &self) {
2695         // Generally, assembly formats are not printed for __repr__ because
2696         // this can cause exceptionally long debug output and exceptions.
2697         // However, types are an exception as they typically have compact
2698         // assembly forms and printing them is useful.
2699         PyPrintAccumulator printAccum;
2700         printAccum.parts.append("Type(");
2701         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2702         printAccum.parts.append(")");
2703         return printAccum.join();
2704       });
2705 
2706   //----------------------------------------------------------------------------
2707   // Mapping of Value.
2708   //----------------------------------------------------------------------------
2709   py::class_<PyValue>(m, "Value", py::module_local())
2710       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2711       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2712       .def_property_readonly(
2713           "context",
2714           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2715           "Context in which the value lives.")
2716       .def(
2717           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2718           kDumpDocstring)
2719       .def_property_readonly(
2720           "owner",
2721           [](PyValue &self) {
2722             assert(mlirOperationEqual(self.getParentOperation()->get(),
2723                                       mlirOpResultGetOwner(self.get())) &&
2724                    "expected the owner of the value in Python to match that in "
2725                    "the IR");
2726             return self.getParentOperation().getObject();
2727           })
2728       .def("__eq__",
2729            [](PyValue &self, PyValue &other) {
2730              return self.get().ptr == other.get().ptr;
2731            })
2732       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2733       .def("__hash__",
2734            [](PyValue &self) {
2735              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2736            })
2737       .def(
2738           "__str__",
2739           [](PyValue &self) {
2740             PyPrintAccumulator printAccum;
2741             printAccum.parts.append("Value(");
2742             mlirValuePrint(self.get(), printAccum.getCallback(),
2743                            printAccum.getUserData());
2744             printAccum.parts.append(")");
2745             return printAccum.join();
2746           },
2747           kValueDunderStrDocstring)
2748       .def_property_readonly("type", [](PyValue &self) {
2749         return PyType(self.getParentOperation()->getContext(),
2750                       mlirValueGetType(self.get()));
2751       });
2752   PyBlockArgument::bind(m);
2753   PyOpResult::bind(m);
2754 
2755   //----------------------------------------------------------------------------
2756   // Mapping of SymbolTable.
2757   //----------------------------------------------------------------------------
2758   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
2759       .def(py::init<PyOperationBase &>())
2760       .def("__getitem__", &PySymbolTable::dunderGetItem)
2761       .def("insert", &PySymbolTable::insert)
2762       .def("erase", &PySymbolTable::erase)
2763       .def("__delitem__", &PySymbolTable::dunderDel)
2764       .def("__contains__", [](PySymbolTable &table, const std::string &name) {
2765         return !mlirOperationIsNull(mlirSymbolTableLookup(
2766             table, mlirStringRefCreate(name.data(), name.length())));
2767       });
2768 
  // Container bindings. Each helper type exposes a static bind() that is
  // handed the module `m`; the bind() definitions are outside this section
  // (presumably each registers its own Python class on `m` — confirm there).
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings, bound through the same static bind(m) entry point.
  PyGlobalDebugFlag::bind(m);
2783 }
2784