1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include <pybind11/stl.h>
23 
24 #include <utility>
25 
26 namespace py = pybind11;
27 using namespace mlir;
28 using namespace mlir::python;
29 
30 using llvm::SmallVector;
31 using llvm::StringRef;
32 using llvm::Twine;
33 
34 //------------------------------------------------------------------------------
35 // Docstrings (trivial, non-duplicated docstrings are included inline).
36 //------------------------------------------------------------------------------
37 
// Docstring text describing how a type is parsed from its assembly form and
// what is returned or raised.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// One-line docstring text for obtaining a call-site Location.
static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

// One-line docstring text for obtaining a file/line/column Location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// One-line docstring text for obtaining a fused Location.
static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

// One-line docstring text for obtaining a named Location.
// NOTE(review): this identifier is spelled "DocString" while the sibling
// constants use "Docstring"; users of the name are outside this section.
static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

// Docstring text describing parsing of a module from a string.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring text listing the arguments and return value of operation
// creation.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
83 
// Docstring text describing the keyword arguments accepted when printing an
// operation. Argument names are spelled in snake_case throughout so that they
// can be copied verbatim as Python keyword arguments.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";
109 
// Docstring text for retrieving an operation's assembly as bytes or str; it
// defers to the print() documentation for the shared keyword arguments.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Docstring text for the default-options string conversion of an operation.
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// One-line docstring text for dumping an object to stderr.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring text for appending a block; used by PyBlockList::bind below.
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Docstring text describing the string form of a value.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
148 
149 //------------------------------------------------------------------------------
150 // Utilities.
151 //------------------------------------------------------------------------------
152 
153 /// Helper for creating an @classmethod.
154 template <class Func, typename... Args>
155 py::object classmethod(Func f, Args... args) {
156   py::object cf = py::cpp_function(f, args...);
157   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
158 }
159 
160 static py::object
161 createCustomDialectWrapper(const std::string &dialectNamespace,
162                            py::object dialectDescriptor) {
163   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
164   if (!dialectClass) {
165     // Use the base class.
166     return py::cast(PyDialect(std::move(dialectDescriptor)));
167   }
168 
169   // Create the custom implementation.
170   return (*dialectClass)(std::move(dialectDescriptor));
171 }
172 
173 static MlirStringRef toMlirStringRef(const std::string &s) {
174   return mlirStringRefCreate(s.data(), s.size());
175 }
176 
177 /// Wrapper for the global LLVM debugging flag.
178 struct PyGlobalDebugFlag {
179   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
180 
181   static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
182 
183   static void bind(py::module &m) {
184     // Debug flags.
185     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
186         .def_property_static("flag", &PyGlobalDebugFlag::get,
187                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
188   }
189 };
190 
191 //------------------------------------------------------------------------------
192 // Collections.
193 //------------------------------------------------------------------------------
194 
195 namespace {
196 
197 class PyRegionIterator {
198 public:
199   PyRegionIterator(PyOperationRef operation)
200       : operation(std::move(operation)) {}
201 
202   PyRegionIterator &dunderIter() { return *this; }
203 
204   PyRegion dunderNext() {
205     operation->checkValid();
206     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
207       throw py::stop_iteration();
208     }
209     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
210     return PyRegion(operation, region);
211   }
212 
213   static void bind(py::module &m) {
214     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
215         .def("__iter__", &PyRegionIterator::dunderIter)
216         .def("__next__", &PyRegionIterator::dunderNext);
217   }
218 
219 private:
220   PyOperationRef operation;
221   int nextIndex = 0;
222 };
223 
224 /// Regions of an op are fixed length and indexed numerically so are represented
225 /// with a sequence-like container.
226 class PyRegionList {
227 public:
228   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
229 
230   intptr_t dunderLen() {
231     operation->checkValid();
232     return mlirOperationGetNumRegions(operation->get());
233   }
234 
235   PyRegion dunderGetItem(intptr_t index) {
236     // dunderLen checks validity.
237     if (index < 0 || index >= dunderLen()) {
238       throw SetPyError(PyExc_IndexError,
239                        "attempt to access out of bounds region");
240     }
241     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
242     return PyRegion(operation, region);
243   }
244 
245   static void bind(py::module &m) {
246     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
247         .def("__len__", &PyRegionList::dunderLen)
248         .def("__getitem__", &PyRegionList::dunderGetItem);
249   }
250 
251 private:
252   PyOperationRef operation;
253 };
254 
255 class PyBlockIterator {
256 public:
257   PyBlockIterator(PyOperationRef operation, MlirBlock next)
258       : operation(std::move(operation)), next(next) {}
259 
260   PyBlockIterator &dunderIter() { return *this; }
261 
262   PyBlock dunderNext() {
263     operation->checkValid();
264     if (mlirBlockIsNull(next)) {
265       throw py::stop_iteration();
266     }
267 
268     PyBlock returnBlock(operation, next);
269     next = mlirBlockGetNextInRegion(next);
270     return returnBlock;
271   }
272 
273   static void bind(py::module &m) {
274     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
275         .def("__iter__", &PyBlockIterator::dunderIter)
276         .def("__next__", &PyBlockIterator::dunderNext);
277   }
278 
279 private:
280   PyOperationRef operation;
281   MlirBlock next;
282 };
283 
284 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
285 /// we present them as a more full-featured list-like container but optimize
286 /// it for forward iteration. Blocks are always owned by a region.
287 class PyBlockList {
288 public:
289   PyBlockList(PyOperationRef operation, MlirRegion region)
290       : operation(std::move(operation)), region(region) {}
291 
292   PyBlockIterator dunderIter() {
293     operation->checkValid();
294     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
295   }
296 
297   intptr_t dunderLen() {
298     operation->checkValid();
299     intptr_t count = 0;
300     MlirBlock block = mlirRegionGetFirstBlock(region);
301     while (!mlirBlockIsNull(block)) {
302       count += 1;
303       block = mlirBlockGetNextInRegion(block);
304     }
305     return count;
306   }
307 
308   PyBlock dunderGetItem(intptr_t index) {
309     operation->checkValid();
310     if (index < 0) {
311       throw SetPyError(PyExc_IndexError,
312                        "attempt to access out of bounds block");
313     }
314     MlirBlock block = mlirRegionGetFirstBlock(region);
315     while (!mlirBlockIsNull(block)) {
316       if (index == 0) {
317         return PyBlock(operation, block);
318       }
319       block = mlirBlockGetNextInRegion(block);
320       index -= 1;
321     }
322     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
323   }
324 
325   PyBlock appendBlock(const py::args &pyArgTypes) {
326     operation->checkValid();
327     llvm::SmallVector<MlirType, 4> argTypes;
328     argTypes.reserve(pyArgTypes.size());
329     for (auto &pyArg : pyArgTypes) {
330       argTypes.push_back(pyArg.cast<PyType &>());
331     }
332 
333     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
334     mlirRegionAppendOwnedBlock(region, block);
335     return PyBlock(operation, block);
336   }
337 
338   static void bind(py::module &m) {
339     py::class_<PyBlockList>(m, "BlockList", py::module_local())
340         .def("__getitem__", &PyBlockList::dunderGetItem)
341         .def("__iter__", &PyBlockList::dunderIter)
342         .def("__len__", &PyBlockList::dunderLen)
343         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
344   }
345 
346 private:
347   PyOperationRef operation;
348   MlirRegion region;
349 };
350 
351 class PyOperationIterator {
352 public:
353   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
354       : parentOperation(std::move(parentOperation)), next(next) {}
355 
356   PyOperationIterator &dunderIter() { return *this; }
357 
358   py::object dunderNext() {
359     parentOperation->checkValid();
360     if (mlirOperationIsNull(next)) {
361       throw py::stop_iteration();
362     }
363 
364     PyOperationRef returnOperation =
365         PyOperation::forOperation(parentOperation->getContext(), next);
366     next = mlirOperationGetNextInBlock(next);
367     return returnOperation->createOpView();
368   }
369 
370   static void bind(py::module &m) {
371     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
372         .def("__iter__", &PyOperationIterator::dunderIter)
373         .def("__next__", &PyOperationIterator::dunderNext);
374   }
375 
376 private:
377   PyOperationRef parentOperation;
378   MlirOperation next;
379 };
380 
381 /// Operations are exposed by the C-API as a forward-only linked list. In
382 /// Python, we present them as a more full-featured list-like container but
383 /// optimize it for forward iteration. Iterable operations are always owned
384 /// by a block.
385 class PyOperationList {
386 public:
387   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
388       : parentOperation(std::move(parentOperation)), block(block) {}
389 
390   PyOperationIterator dunderIter() {
391     parentOperation->checkValid();
392     return PyOperationIterator(parentOperation,
393                                mlirBlockGetFirstOperation(block));
394   }
395 
396   intptr_t dunderLen() {
397     parentOperation->checkValid();
398     intptr_t count = 0;
399     MlirOperation childOp = mlirBlockGetFirstOperation(block);
400     while (!mlirOperationIsNull(childOp)) {
401       count += 1;
402       childOp = mlirOperationGetNextInBlock(childOp);
403     }
404     return count;
405   }
406 
407   py::object dunderGetItem(intptr_t index) {
408     parentOperation->checkValid();
409     if (index < 0) {
410       throw SetPyError(PyExc_IndexError,
411                        "attempt to access out of bounds operation");
412     }
413     MlirOperation childOp = mlirBlockGetFirstOperation(block);
414     while (!mlirOperationIsNull(childOp)) {
415       if (index == 0) {
416         return PyOperation::forOperation(parentOperation->getContext(), childOp)
417             ->createOpView();
418       }
419       childOp = mlirOperationGetNextInBlock(childOp);
420       index -= 1;
421     }
422     throw SetPyError(PyExc_IndexError,
423                      "attempt to access out of bounds operation");
424   }
425 
426   static void bind(py::module &m) {
427     py::class_<PyOperationList>(m, "OperationList", py::module_local())
428         .def("__getitem__", &PyOperationList::dunderGetItem)
429         .def("__iter__", &PyOperationList::dunderIter)
430         .def("__len__", &PyOperationList::dunderLen);
431   }
432 
433 private:
434   PyOperationRef parentOperation;
435   MlirBlock block;
436 };
437 
438 } // namespace
439 
440 //------------------------------------------------------------------------------
441 // PyMlirContext
442 //------------------------------------------------------------------------------
443 
444 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
445   py::gil_scoped_acquire acquire;
446   auto &liveContexts = getLiveContexts();
447   liveContexts[context.ptr] = this;
448 }
449 
450 PyMlirContext::~PyMlirContext() {
451   // Note that the only public way to construct an instance is via the
452   // forContext method, which always puts the associated handle into
453   // liveContexts.
454   py::gil_scoped_acquire acquire;
455   getLiveContexts().erase(context.ptr);
456   mlirContextDestroy(context);
457 }
458 
459 py::object PyMlirContext::getCapsule() {
460   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
461 }
462 
463 py::object PyMlirContext::createFromCapsule(py::object capsule) {
464   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
465   if (mlirContextIsNull(rawContext))
466     throw py::error_already_set();
467   return forContext(rawContext).releaseObject();
468 }
469 
470 PyMlirContext *PyMlirContext::createNewContextForInit() {
471   MlirContext context = mlirContextCreate();
472   mlirRegisterAllDialects(context);
473   return new PyMlirContext(context);
474 }
475 
476 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
477   py::gil_scoped_acquire acquire;
478   auto &liveContexts = getLiveContexts();
479   auto it = liveContexts.find(context.ptr);
480   if (it == liveContexts.end()) {
481     // Create.
482     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
483     py::object pyRef = py::cast(unownedContextWrapper);
484     assert(pyRef && "cast to py::object failed");
485     liveContexts[context.ptr] = unownedContextWrapper;
486     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
487   }
488   // Use existing.
489   py::object pyRef = py::cast(it->second);
490   return PyMlirContextRef(it->second, std::move(pyRef));
491 }
492 
493 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
494   static LiveContextMap liveContexts;
495   return liveContexts;
496 }
497 
498 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
499 
500 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
501 
502 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
503 
504 pybind11::object PyMlirContext::contextEnter() {
505   return PyThreadContextEntry::pushContext(*this);
506 }
507 
508 void PyMlirContext::contextExit(const pybind11::object &excType,
509                                 const pybind11::object &excVal,
510                                 const pybind11::object &excTb) {
511   PyThreadContextEntry::popContext(*this);
512 }
513 
514 PyMlirContext &DefaultingPyMlirContext::resolve() {
515   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
516   if (!context) {
517     throw SetPyError(
518         PyExc_RuntimeError,
519         "An MLIR function requires a Context but none was provided in the call "
520         "or from the surrounding environment. Either pass to the function with "
521         "a 'context=' argument or establish a default using 'with Context():'");
522   }
523   return *context;
524 }
525 
526 //------------------------------------------------------------------------------
527 // PyThreadContextEntry management
528 //------------------------------------------------------------------------------
529 
530 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
531   static thread_local std::vector<PyThreadContextEntry> stack;
532   return stack;
533 }
534 
535 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
536   auto &stack = getStack();
537   if (stack.empty())
538     return nullptr;
539   return &stack.back();
540 }
541 
542 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
543                                 py::object insertionPoint,
544                                 py::object location) {
545   auto &stack = getStack();
546   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
547                      std::move(location));
548   // If the new stack has more than one entry and the context of the new top
549   // entry matches the previous, copy the insertionPoint and location from the
550   // previous entry if missing from the new top entry.
551   if (stack.size() > 1) {
552     auto &prev = *(stack.rbegin() + 1);
553     auto &current = stack.back();
554     if (current.context.is(prev.context)) {
555       // Default non-context objects from the previous entry.
556       if (!current.insertionPoint)
557         current.insertionPoint = prev.insertionPoint;
558       if (!current.location)
559         current.location = prev.location;
560     }
561   }
562 }
563 
564 PyMlirContext *PyThreadContextEntry::getContext() {
565   if (!context)
566     return nullptr;
567   return py::cast<PyMlirContext *>(context);
568 }
569 
570 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
571   if (!insertionPoint)
572     return nullptr;
573   return py::cast<PyInsertionPoint *>(insertionPoint);
574 }
575 
576 PyLocation *PyThreadContextEntry::getLocation() {
577   if (!location)
578     return nullptr;
579   return py::cast<PyLocation *>(location);
580 }
581 
582 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
583   auto *tos = getTopOfStack();
584   return tos ? tos->getContext() : nullptr;
585 }
586 
587 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
588   auto *tos = getTopOfStack();
589   return tos ? tos->getInsertionPoint() : nullptr;
590 }
591 
592 PyLocation *PyThreadContextEntry::getDefaultLocation() {
593   auto *tos = getTopOfStack();
594   return tos ? tos->getLocation() : nullptr;
595 }
596 
597 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
598   py::object contextObj = py::cast(context);
599   push(FrameKind::Context, /*context=*/contextObj,
600        /*insertionPoint=*/py::object(),
601        /*location=*/py::object());
602   return contextObj;
603 }
604 
605 void PyThreadContextEntry::popContext(PyMlirContext &context) {
606   auto &stack = getStack();
607   if (stack.empty())
608     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
609   auto &tos = stack.back();
610   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
611     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
612   stack.pop_back();
613 }
614 
615 py::object
616 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
617   py::object contextObj =
618       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
619   py::object insertionPointObj = py::cast(insertionPoint);
620   push(FrameKind::InsertionPoint,
621        /*context=*/contextObj,
622        /*insertionPoint=*/insertionPointObj,
623        /*location=*/py::object());
624   return insertionPointObj;
625 }
626 
627 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
628   auto &stack = getStack();
629   if (stack.empty())
630     throw SetPyError(PyExc_RuntimeError,
631                      "Unbalanced InsertionPoint enter/exit");
632   auto &tos = stack.back();
633   if (tos.frameKind != FrameKind::InsertionPoint &&
634       tos.getInsertionPoint() != &insertionPoint)
635     throw SetPyError(PyExc_RuntimeError,
636                      "Unbalanced InsertionPoint enter/exit");
637   stack.pop_back();
638 }
639 
640 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
641   py::object contextObj = location.getContext().getObject();
642   py::object locationObj = py::cast(location);
643   push(FrameKind::Location, /*context=*/contextObj,
644        /*insertionPoint=*/py::object(),
645        /*location=*/locationObj);
646   return locationObj;
647 }
648 
649 void PyThreadContextEntry::popLocation(PyLocation &location) {
650   auto &stack = getStack();
651   if (stack.empty())
652     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
653   auto &tos = stack.back();
654   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
655     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
656   stack.pop_back();
657 }
658 
659 //------------------------------------------------------------------------------
660 // PyDialect, PyDialectDescriptor, PyDialects
661 //------------------------------------------------------------------------------
662 
663 MlirDialect PyDialects::getDialectForKey(const std::string &key,
664                                          bool attrError) {
665   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
666                                                     {key.data(), key.size()});
667   if (mlirDialectIsNull(dialect)) {
668     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
669                      Twine("Dialect '") + key + "' not found");
670   }
671   return dialect;
672 }
673 
674 //------------------------------------------------------------------------------
675 // PyLocation
676 //------------------------------------------------------------------------------
677 
678 py::object PyLocation::getCapsule() {
679   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
680 }
681 
682 PyLocation PyLocation::createFromCapsule(py::object capsule) {
683   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
684   if (mlirLocationIsNull(rawLoc))
685     throw py::error_already_set();
686   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
687                     rawLoc);
688 }
689 
690 py::object PyLocation::contextEnter() {
691   return PyThreadContextEntry::pushLocation(*this);
692 }
693 
694 void PyLocation::contextExit(const pybind11::object &excType,
695                              const pybind11::object &excVal,
696                              const pybind11::object &excTb) {
697   PyThreadContextEntry::popLocation(*this);
698 }
699 
700 PyLocation &DefaultingPyLocation::resolve() {
701   auto *location = PyThreadContextEntry::getDefaultLocation();
702   if (!location) {
703     throw SetPyError(
704         PyExc_RuntimeError,
705         "An MLIR function requires a Location but none was provided in the "
706         "call or from the surrounding environment. Either pass to the function "
707         "with a 'loc=' argument or establish a default using 'with loc:'");
708   }
709   return *location;
710 }
711 
712 //------------------------------------------------------------------------------
713 // PyModule
714 //------------------------------------------------------------------------------
715 
716 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
717     : BaseContextObject(std::move(contextRef)), module(module) {}
718 
719 PyModule::~PyModule() {
720   py::gil_scoped_acquire acquire;
721   auto &liveModules = getContext()->liveModules;
722   assert(liveModules.count(module.ptr) == 1 &&
723          "destroying module not in live map");
724   liveModules.erase(module.ptr);
725   mlirModuleDestroy(module);
726 }
727 
728 PyModuleRef PyModule::forModule(MlirModule module) {
729   MlirContext context = mlirModuleGetContext(module);
730   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
731 
732   py::gil_scoped_acquire acquire;
733   auto &liveModules = contextRef->liveModules;
734   auto it = liveModules.find(module.ptr);
735   if (it == liveModules.end()) {
736     // Create.
737     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
738     // Note that the default return value policy on cast is automatic_reference,
739     // which does not take ownership (delete will not be called).
740     // Just be explicit.
741     py::object pyRef =
742         py::cast(unownedModule, py::return_value_policy::take_ownership);
743     unownedModule->handle = pyRef;
744     liveModules[module.ptr] =
745         std::make_pair(unownedModule->handle, unownedModule);
746     return PyModuleRef(unownedModule, std::move(pyRef));
747   }
748   // Use existing.
749   PyModule *existing = it->second.second;
750   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
751   return PyModuleRef(existing, std::move(pyRef));
752 }
753 
754 py::object PyModule::createFromCapsule(py::object capsule) {
755   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
756   if (mlirModuleIsNull(rawModule))
757     throw py::error_already_set();
758   return forModule(rawModule).releaseObject();
759 }
760 
761 py::object PyModule::getCapsule() {
762   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
763 }
764 
765 //------------------------------------------------------------------------------
766 // PyOperation
767 //------------------------------------------------------------------------------
768 
769 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
770     : BaseContextObject(std::move(contextRef)), operation(operation) {}
771 
772 PyOperation::~PyOperation() {
773   // If the operation has already been invalidated there is nothing to do.
774   if (!valid)
775     return;
776   auto &liveOperations = getContext()->liveOperations;
777   assert(liveOperations.count(operation.ptr) == 1 &&
778          "destroying operation not in live map");
779   liveOperations.erase(operation.ptr);
780   if (!isAttached()) {
781     mlirOperationDestroy(operation);
782   }
783 }
784 
785 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
786                                            MlirOperation operation,
787                                            py::object parentKeepAlive) {
788   auto &liveOperations = contextRef->liveOperations;
789   // Create.
790   PyOperation *unownedOperation =
791       new PyOperation(std::move(contextRef), operation);
792   // Note that the default return value policy on cast is automatic_reference,
793   // which does not take ownership (delete will not be called).
794   // Just be explicit.
795   py::object pyRef =
796       py::cast(unownedOperation, py::return_value_policy::take_ownership);
797   unownedOperation->handle = pyRef;
798   if (parentKeepAlive) {
799     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
800   }
801   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
802   return PyOperationRef(unownedOperation, std::move(pyRef));
803 }
804 
805 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
806                                          MlirOperation operation,
807                                          py::object parentKeepAlive) {
808   auto &liveOperations = contextRef->liveOperations;
809   auto it = liveOperations.find(operation.ptr);
810   if (it == liveOperations.end()) {
811     // Create.
812     return createInstance(std::move(contextRef), operation,
813                           std::move(parentKeepAlive));
814   }
815   // Use existing.
816   PyOperation *existing = it->second.second;
817   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
818   return PyOperationRef(existing, std::move(pyRef));
819 }
820 
821 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
822                                            MlirOperation operation,
823                                            py::object parentKeepAlive) {
824   auto &liveOperations = contextRef->liveOperations;
825   assert(liveOperations.count(operation.ptr) == 0 &&
826          "cannot create detached operation that already exists");
827   (void)liveOperations;
828 
829   PyOperationRef created = createInstance(std::move(contextRef), operation,
830                                           std::move(parentKeepAlive));
831   created->attached = false;
832   return created;
833 }
834 
835 void PyOperation::checkValid() const {
836   if (!valid) {
837     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
838   }
839 }
840 
841 void PyOperationBase::print(py::object fileObject, bool binary,
842                             llvm::Optional<int64_t> largeElementsLimit,
843                             bool enableDebugInfo, bool prettyDebugInfo,
844                             bool printGenericOpForm, bool useLocalScope,
845                             bool assumeVerified) {
846   PyOperation &operation = getOperation();
847   operation.checkValid();
848   if (fileObject.is_none())
849     fileObject = py::module::import("sys").attr("stdout");
850 
851   if (!assumeVerified && !printGenericOpForm &&
852       !mlirOperationVerify(operation)) {
853     std::string message("// Verification failed, printing generic form\n");
854     if (binary) {
855       fileObject.attr("write")(py::bytes(message));
856     } else {
857       fileObject.attr("write")(py::str(message));
858     }
859     printGenericOpForm = true;
860   }
861 
862   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
863   if (largeElementsLimit)
864     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
865   if (enableDebugInfo)
866     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
867   if (printGenericOpForm)
868     mlirOpPrintingFlagsPrintGenericOpForm(flags);
869 
870   PyFileAccumulator accum(fileObject, binary);
871   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
872                               accum.getUserData());
873   mlirOpPrintingFlagsDestroy(flags);
874 }
875 
876 py::object PyOperationBase::getAsm(bool binary,
877                                    llvm::Optional<int64_t> largeElementsLimit,
878                                    bool enableDebugInfo, bool prettyDebugInfo,
879                                    bool printGenericOpForm, bool useLocalScope,
880                                    bool assumeVerified) {
881   py::object fileObject;
882   if (binary) {
883     fileObject = py::module::import("io").attr("BytesIO")();
884   } else {
885     fileObject = py::module::import("io").attr("StringIO")();
886   }
887   print(fileObject, /*binary=*/binary,
888         /*largeElementsLimit=*/largeElementsLimit,
889         /*enableDebugInfo=*/enableDebugInfo,
890         /*prettyDebugInfo=*/prettyDebugInfo,
891         /*printGenericOpForm=*/printGenericOpForm,
892         /*useLocalScope=*/useLocalScope,
893         /*assumeVerified=*/assumeVerified);
894 
895   return fileObject.attr("getvalue")();
896 }
897 
898 void PyOperationBase::moveAfter(PyOperationBase &other) {
899   PyOperation &operation = getOperation();
900   PyOperation &otherOp = other.getOperation();
901   operation.checkValid();
902   otherOp.checkValid();
903   mlirOperationMoveAfter(operation, otherOp);
904   operation.parentKeepAlive = otherOp.parentKeepAlive;
905 }
906 
907 void PyOperationBase::moveBefore(PyOperationBase &other) {
908   PyOperation &operation = getOperation();
909   PyOperation &otherOp = other.getOperation();
910   operation.checkValid();
911   otherOp.checkValid();
912   mlirOperationMoveBefore(operation, otherOp);
913   operation.parentKeepAlive = otherOp.parentKeepAlive;
914 }
915 
916 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
917   checkValid();
918   if (!isAttached())
919     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
920   MlirOperation operation = mlirOperationGetParentOperation(get());
921   if (mlirOperationIsNull(operation))
922     return {};
923   return PyOperation::forOperation(getContext(), operation);
924 }
925 
926 PyBlock PyOperation::getBlock() {
927   checkValid();
928   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
929   MlirBlock block = mlirOperationGetBlock(get());
930   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
931   assert(parentOperation && "Operation has no parent");
932   return PyBlock{std::move(*parentOperation), block};
933 }
934 
935 py::object PyOperation::getCapsule() {
936   checkValid();
937   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
938 }
939 
940 py::object PyOperation::createFromCapsule(py::object capsule) {
941   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
942   if (mlirOperationIsNull(rawOperation))
943     throw py::error_already_set();
944   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
945   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
946       .releaseObject();
947 }
948 
// Creates a new operation named `name` and returns its OpView object.
//
// Each optional argument (results, operands, attributes, successors) is
// unpacked and validated first; a null entry in any list raises ValueError and
// a malformed attribute key/value raises a cast error. `regions` is the number
// of empty regions to create and must be >= 0. The operation is created
// detached in the context of `location` and then inserted at an insertion
// point unless `maybeIp` is the Python `False` object: `None` selects the
// thread's default insertion point (if any), anything else is cast to a
// PyInsertionPoint.
py::object PyOperation::create(
    const std::string &name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, const py::object &maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      // Keys must convert to std::string; rethrow with the op name for context.
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      // Values must be PyAttribute instances. The reference_cast_error handler
      // is listed before the cast_error handler so it is tried first.
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Create `regions` fresh regions and hand them to the operation state.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);

  // InsertPoint active?
  // Passing the Python `False` object opts out of insertion entirely.
  if (!maybeIp.is(py::cast(false))) {
    PyInsertionPoint *ip;
    if (maybeIp.is_none()) {
      ip = PyThreadContextEntry::getDefaultInsertionPoint();
    } else {
      ip = py::cast<PyInsertionPoint *>(maybeIp);
    }
    // The resolved insertion point may be null, leaving the op detached.
    if (ip)
      ip->insert(*created.get());
  }

  return created->createOpView();
}
1081 
1082 py::object PyOperation::createOpView() {
1083   checkValid();
1084   MlirIdentifier ident = mlirOperationGetName(get());
1085   MlirStringRef identStr = mlirIdentifierStr(ident);
1086   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1087       StringRef(identStr.data, identStr.length));
1088   if (opViewClass)
1089     return (*opViewClass)(getRef().getObject());
1090   return py::cast(PyOpView(getRef().getObject()));
1091 }
1092 
1093 void PyOperation::erase() {
1094   checkValid();
1095   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1096   // Python reference to a child operation is live. All children should also
1097   // have their `valid` bit set to false.
1098   auto &liveOperations = getContext()->liveOperations;
1099   if (liveOperations.count(operation.ptr))
1100     liveOperations.erase(operation.ptr);
1101   mlirOperationDestroy(operation);
1102   valid = false;
1103 }
1104 
1105 //------------------------------------------------------------------------------
1106 // PyOpView
1107 //------------------------------------------------------------------------------
1108 
1109 py::object PyOpView::buildGeneric(
1110     const py::object &cls, py::list resultTypeList, py::list operandList,
1111     llvm::Optional<py::dict> attributes,
1112     llvm::Optional<std::vector<PyBlock *>> successors,
1113     llvm::Optional<int> regions, DefaultingPyLocation location,
1114     const py::object &maybeIp) {
1115   PyMlirContextRef context = location->getContext();
1116   // Class level operation construction metadata.
1117   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1118   // Operand and result segment specs are either none, which does no
1119   // variadic unpacking, or a list of ints with segment sizes, where each
1120   // element is either a positive number (typically 1 for a scalar) or -1 to
1121   // indicate that it is derived from the length of the same-indexed operand
1122   // or result (implying that it is a list at that position).
1123   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1124   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1125 
1126   std::vector<uint32_t> operandSegmentLengths;
1127   std::vector<uint32_t> resultSegmentLengths;
1128 
1129   // Validate/determine region count.
1130   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1131   int opMinRegionCount = std::get<0>(opRegionSpec);
1132   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1133   if (!regions) {
1134     regions = opMinRegionCount;
1135   }
1136   if (*regions < opMinRegionCount) {
1137     throw py::value_error(
1138         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1139          llvm::Twine(opMinRegionCount) +
1140          " regions but was built with regions=" + llvm::Twine(*regions))
1141             .str());
1142   }
1143   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1144     throw py::value_error(
1145         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1146          llvm::Twine(opMinRegionCount) +
1147          " regions but was built with regions=" + llvm::Twine(*regions))
1148             .str());
1149   }
1150 
1151   // Unpack results.
1152   std::vector<PyType *> resultTypes;
1153   resultTypes.reserve(resultTypeList.size());
1154   if (resultSegmentSpecObj.is_none()) {
1155     // Non-variadic result unpacking.
1156     for (const auto &it : llvm::enumerate(resultTypeList)) {
1157       try {
1158         resultTypes.push_back(py::cast<PyType *>(it.value()));
1159         if (!resultTypes.back())
1160           throw py::cast_error();
1161       } catch (py::cast_error &err) {
1162         throw py::value_error((llvm::Twine("Result ") +
1163                                llvm::Twine(it.index()) + " of operation \"" +
1164                                name + "\" must be a Type (" + err.what() + ")")
1165                                   .str());
1166       }
1167     }
1168   } else {
1169     // Sized result unpacking.
1170     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1171     if (resultSegmentSpec.size() != resultTypeList.size()) {
1172       throw py::value_error((llvm::Twine("Operation \"") + name +
1173                              "\" requires " +
1174                              llvm::Twine(resultSegmentSpec.size()) +
1175                              " result segments but was provided " +
1176                              llvm::Twine(resultTypeList.size()))
1177                                 .str());
1178     }
1179     resultSegmentLengths.reserve(resultTypeList.size());
1180     for (const auto &it :
1181          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1182       int segmentSpec = std::get<1>(it.value());
1183       if (segmentSpec == 1 || segmentSpec == 0) {
1184         // Unpack unary element.
1185         try {
1186           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1187           if (resultType) {
1188             resultTypes.push_back(resultType);
1189             resultSegmentLengths.push_back(1);
1190           } else if (segmentSpec == 0) {
1191             // Allowed to be optional.
1192             resultSegmentLengths.push_back(0);
1193           } else {
1194             throw py::cast_error("was None and result is not optional");
1195           }
1196         } catch (py::cast_error &err) {
1197           throw py::value_error((llvm::Twine("Result ") +
1198                                  llvm::Twine(it.index()) + " of operation \"" +
1199                                  name + "\" must be a Type (" + err.what() +
1200                                  ")")
1201                                     .str());
1202         }
1203       } else if (segmentSpec == -1) {
1204         // Unpack sequence by appending.
1205         try {
1206           if (std::get<0>(it.value()).is_none()) {
1207             // Treat it as an empty list.
1208             resultSegmentLengths.push_back(0);
1209           } else {
1210             // Unpack the list.
1211             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1212             for (py::object segmentItem : segment) {
1213               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1214               if (!resultTypes.back()) {
1215                 throw py::cast_error("contained a None item");
1216               }
1217             }
1218             resultSegmentLengths.push_back(segment.size());
1219           }
1220         } catch (std::exception &err) {
1221           // NOTE: Sloppy to be using a catch-all here, but there are at least
1222           // three different unrelated exceptions that can be thrown in the
1223           // above "casts". Just keep the scope above small and catch them all.
1224           throw py::value_error((llvm::Twine("Result ") +
1225                                  llvm::Twine(it.index()) + " of operation \"" +
1226                                  name + "\" must be a Sequence of Types (" +
1227                                  err.what() + ")")
1228                                     .str());
1229         }
1230       } else {
1231         throw py::value_error("Unexpected segment spec");
1232       }
1233     }
1234   }
1235 
1236   // Unpack operands.
1237   std::vector<PyValue *> operands;
1238   operands.reserve(operands.size());
1239   if (operandSegmentSpecObj.is_none()) {
1240     // Non-sized operand unpacking.
1241     for (const auto &it : llvm::enumerate(operandList)) {
1242       try {
1243         operands.push_back(py::cast<PyValue *>(it.value()));
1244         if (!operands.back())
1245           throw py::cast_error();
1246       } catch (py::cast_error &err) {
1247         throw py::value_error((llvm::Twine("Operand ") +
1248                                llvm::Twine(it.index()) + " of operation \"" +
1249                                name + "\" must be a Value (" + err.what() + ")")
1250                                   .str());
1251       }
1252     }
1253   } else {
1254     // Sized operand unpacking.
1255     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1256     if (operandSegmentSpec.size() != operandList.size()) {
1257       throw py::value_error((llvm::Twine("Operation \"") + name +
1258                              "\" requires " +
1259                              llvm::Twine(operandSegmentSpec.size()) +
1260                              "operand segments but was provided " +
1261                              llvm::Twine(operandList.size()))
1262                                 .str());
1263     }
1264     operandSegmentLengths.reserve(operandList.size());
1265     for (const auto &it :
1266          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1267       int segmentSpec = std::get<1>(it.value());
1268       if (segmentSpec == 1 || segmentSpec == 0) {
1269         // Unpack unary element.
1270         try {
1271           auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1272           if (operandValue) {
1273             operands.push_back(operandValue);
1274             operandSegmentLengths.push_back(1);
1275           } else if (segmentSpec == 0) {
1276             // Allowed to be optional.
1277             operandSegmentLengths.push_back(0);
1278           } else {
1279             throw py::cast_error("was None and operand is not optional");
1280           }
1281         } catch (py::cast_error &err) {
1282           throw py::value_error((llvm::Twine("Operand ") +
1283                                  llvm::Twine(it.index()) + " of operation \"" +
1284                                  name + "\" must be a Value (" + err.what() +
1285                                  ")")
1286                                     .str());
1287         }
1288       } else if (segmentSpec == -1) {
1289         // Unpack sequence by appending.
1290         try {
1291           if (std::get<0>(it.value()).is_none()) {
1292             // Treat it as an empty list.
1293             operandSegmentLengths.push_back(0);
1294           } else {
1295             // Unpack the list.
1296             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1297             for (py::object segmentItem : segment) {
1298               operands.push_back(py::cast<PyValue *>(segmentItem));
1299               if (!operands.back()) {
1300                 throw py::cast_error("contained a None item");
1301               }
1302             }
1303             operandSegmentLengths.push_back(segment.size());
1304           }
1305         } catch (std::exception &err) {
1306           // NOTE: Sloppy to be using a catch-all here, but there are at least
1307           // three different unrelated exceptions that can be thrown in the
1308           // above "casts". Just keep the scope above small and catch them all.
1309           throw py::value_error((llvm::Twine("Operand ") +
1310                                  llvm::Twine(it.index()) + " of operation \"" +
1311                                  name + "\" must be a Sequence of Values (" +
1312                                  err.what() + ")")
1313                                     .str());
1314         }
1315       } else {
1316         throw py::value_error("Unexpected segment spec");
1317       }
1318     }
1319   }
1320 
1321   // Merge operand/result segment lengths into attributes if needed.
1322   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1323     // Dup.
1324     if (attributes) {
1325       attributes = py::dict(*attributes);
1326     } else {
1327       attributes = py::dict();
1328     }
1329     if (attributes->contains("result_segment_sizes") ||
1330         attributes->contains("operand_segment_sizes")) {
1331       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1332                             "'operand_segment_sizes' attribute is unsupported. "
1333                             "Use Operation.create for such low-level access.");
1334     }
1335 
1336     // Add result_segment_sizes attribute.
1337     if (!resultSegmentLengths.empty()) {
1338       int64_t size = resultSegmentLengths.size();
1339       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1340           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1341           resultSegmentLengths.size(), resultSegmentLengths.data());
1342       (*attributes)["result_segment_sizes"] =
1343           PyAttribute(context, segmentLengthAttr);
1344     }
1345 
1346     // Add operand_segment_sizes attribute.
1347     if (!operandSegmentLengths.empty()) {
1348       int64_t size = operandSegmentLengths.size();
1349       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1350           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1351           operandSegmentLengths.size(), operandSegmentLengths.data());
1352       (*attributes)["operand_segment_sizes"] =
1353           PyAttribute(context, segmentLengthAttr);
1354     }
1355   }
1356 
1357   // Delegate to create.
1358   return PyOperation::create(name,
1359                              /*results=*/std::move(resultTypes),
1360                              /*operands=*/std::move(operands),
1361                              /*attributes=*/std::move(attributes),
1362                              /*successors=*/std::move(successors),
1363                              /*regions=*/*regions, location, maybeIp);
1364 }
1365 
// Constructs an OpView over the operation behind `operationObject`, which may
// be any Python object castable to PyOperationBase.
PyOpView::PyOpView(const py::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      // Reads the `operation` member initialized above rather than the
      // parameter. NOTE(review): this relies on `operation` being declared
      // before `operationObject` in the class — confirm in the header.
      operationObject(operation.getRef().getObject()) {}
1371 
1372 py::object PyOpView::createRawSubclass(const py::object &userClass) {
1373   // This is... a little gross. The typical pattern is to have a pure python
1374   // class that extends OpView like:
1375   //   class AddFOp(_cext.ir.OpView):
1376   //     def __init__(self, loc, lhs, rhs):
1377   //       operation = loc.context.create_operation(
1378   //           "addf", lhs, rhs, results=[lhs.type])
1379   //       super().__init__(operation)
1380   //
1381   // I.e. The goal of the user facing type is to provide a nice constructor
1382   // that has complete freedom for the op under construction. This is at odds
1383   // with our other desire to sometimes create this object by just passing an
1384   // operation (to initialize the base class). We could do *arg and **kwargs
1385   // munging to try to make it work, but instead, we synthesize a new class
1386   // on the fly which extends this user class (AddFOp in this example) and
1387   // *give it* the base class's __init__ method, thus bypassing the
1388   // intermediate subclass's __init__ method entirely. While slightly,
1389   // underhanded, this is safe/legal because the type hierarchy has not changed
1390   // (we just added a new leaf) and we aren't mucking around with __new__.
1391   // Typically, this new class will be stored on the original as "_Raw" and will
1392   // be used for casts and other things that need a variant of the class that
1393   // is initialized purely from an operation.
1394   py::object parentMetaclass =
1395       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1396   py::dict attributes;
1397   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1398   // now.
1399   //   auto opViewType = py::type::of<PyOpView>();
1400   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1401   attributes["__init__"] = opViewType.attr("__init__");
1402   py::str origName = userClass.attr("__name__");
1403   py::str newName = py::str("_") + origName;
1404   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1405 }
1406 
1407 //------------------------------------------------------------------------------
1408 // PyInsertionPoint.
1409 //------------------------------------------------------------------------------
1410 
// Creates an insertion point for `block` with no reference operation; insert()
// treats that as "insert at the end of the block".
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1412 
// Creates an insertion point that inserts before the given operation, in the
// block that contains it.
// NOTE(review): `block` is computed from `refOperation`, so this relies on
// `refOperation` being declared before `block` in the class — confirm in the
// header.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1416 
1417 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1418   PyOperation &operation = operationBase.getOperation();
1419   if (operation.isAttached())
1420     throw SetPyError(PyExc_ValueError,
1421                      "Attempt to insert operation that is already attached");
1422   block.getParentOperation()->checkValid();
1423   MlirOperation beforeOp = {nullptr};
1424   if (refOperation) {
1425     // Insert before operation.
1426     (*refOperation)->checkValid();
1427     beforeOp = (*refOperation)->get();
1428   } else {
1429     // Insert at end (before null) is only valid if the block does not
1430     // already end in a known terminator (violating this will cause assertion
1431     // failures later).
1432     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1433       throw py::index_error("Cannot insert operation at the end of a block "
1434                             "that already has a terminator. Did you mean to "
1435                             "use 'InsertionPoint.at_block_terminator(block)' "
1436                             "versus 'InsertionPoint(block)'?");
1437     }
1438   }
1439   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1440   operation.setAttached();
1441 }
1442 
1443 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1444   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1445   if (mlirOperationIsNull(firstOp)) {
1446     // Just insert at end.
1447     return PyInsertionPoint(block);
1448   }
1449 
1450   // Insert before first op.
1451   PyOperationRef firstOpRef = PyOperation::forOperation(
1452       block.getParentOperation()->getContext(), firstOp);
1453   return PyInsertionPoint{block, std::move(firstOpRef)};
1454 }
1455 
1456 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1457   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1458   if (mlirOperationIsNull(terminator))
1459     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1460   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1461       block.getParentOperation()->getContext(), terminator);
1462   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1463 }
1464 
1465 py::object PyInsertionPoint::contextEnter() {
1466   return PyThreadContextEntry::pushInsertionPoint(*this);
1467 }
1468 
1469 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1470                                    const pybind11::object &excVal,
1471                                    const pybind11::object &excTb) {
1472   PyThreadContextEntry::popInsertionPoint(*this);
1473 }
1474 
1475 //------------------------------------------------------------------------------
1476 // PyAttribute.
1477 //------------------------------------------------------------------------------
1478 
1479 bool PyAttribute::operator==(const PyAttribute &other) {
1480   return mlirAttributeEqual(attr, other.attr);
1481 }
1482 
1483 py::object PyAttribute::getCapsule() {
1484   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1485 }
1486 
1487 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1488   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1489   if (mlirAttributeIsNull(rawAttr))
1490     throw py::error_already_set();
1491   return PyAttribute(
1492       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1493 }
1494 
1495 //------------------------------------------------------------------------------
1496 // PyNamedAttribute.
1497 //------------------------------------------------------------------------------
1498 
1499 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1500     : ownedName(new std::string(std::move(ownedName))) {
1501   namedAttr = mlirNamedAttributeGet(
1502       mlirIdentifierGet(mlirAttributeGetContext(attr),
1503                         toMlirStringRef(*this->ownedName)),
1504       attr);
1505 }
1506 
1507 //------------------------------------------------------------------------------
1508 // PyType.
1509 //------------------------------------------------------------------------------
1510 
1511 bool PyType::operator==(const PyType &other) {
1512   return mlirTypeEqual(type, other.type);
1513 }
1514 
1515 py::object PyType::getCapsule() {
1516   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1517 }
1518 
1519 PyType PyType::createFromCapsule(py::object capsule) {
1520   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1521   if (mlirTypeIsNull(rawType))
1522     throw py::error_already_set();
1523   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1524                 rawType);
1525 }
1526 
1527 //------------------------------------------------------------------------------
// PyValue and subclasses.
1529 //------------------------------------------------------------------------------
1530 
1531 pybind11::object PyValue::getCapsule() {
1532   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1533 }
1534 
1535 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1536   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1537   if (mlirValueIsNull(value))
1538     throw py::error_already_set();
1539   MlirOperation owner;
1540   if (mlirValueIsAOpResult(value))
1541     owner = mlirOpResultGetOwner(value);
1542   if (mlirValueIsABlockArgument(value))
1543     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1544   if (mlirOperationIsNull(owner))
1545     throw py::error_already_set();
1546   MlirContext ctx = mlirOperationGetContext(owner);
1547   PyOperationRef ownerRef =
1548       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1549   return PyValue(ownerRef, value);
1550 }
1551 
1552 //------------------------------------------------------------------------------
1553 // PySymbolTable.
1554 //------------------------------------------------------------------------------
1555 
1556 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1557     : operation(operation.getOperation().getRef()) {
1558   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1559   if (mlirSymbolTableIsNull(symbolTable)) {
1560     throw py::cast_error("Operation is not a Symbol Table.");
1561   }
1562 }
1563 
1564 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1565   operation->checkValid();
1566   MlirOperation symbol = mlirSymbolTableLookup(
1567       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1568   if (mlirOperationIsNull(symbol))
1569     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1570 
1571   return PyOperation::forOperation(operation->getContext(), symbol,
1572                                    operation.getObject())
1573       ->createOpView();
1574 }
1575 
1576 void PySymbolTable::erase(PyOperationBase &symbol) {
1577   operation->checkValid();
1578   symbol.getOperation().checkValid();
1579   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1580   // The operation is also erased, so we must invalidate it. There may be Python
1581   // references to this operation so we don't want to delete it from the list of
1582   // live operations here.
1583   symbol.getOperation().valid = false;
1584 }
1585 
1586 void PySymbolTable::dunderDel(const std::string &name) {
1587   py::object operation = dunderGetItem(name);
1588   erase(py::cast<PyOperationBase &>(operation));
1589 }
1590 
1591 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1592   operation->checkValid();
1593   symbol.getOperation().checkValid();
1594   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1595       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1596   if (mlirAttributeIsNull(symbolAttr))
1597     throw py::value_error("Expected operation to have a symbol name.");
1598   return PyAttribute(
1599       symbol.getOperation().getContext(),
1600       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1601 }
1602 
1603 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1604   // Op must already be a symbol.
1605   PyOperation &operation = symbol.getOperation();
1606   operation.checkValid();
1607   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1608   MlirAttribute existingNameAttr =
1609       mlirOperationGetAttributeByName(operation.get(), attrName);
1610   if (mlirAttributeIsNull(existingNameAttr))
1611     throw py::value_error("Expected operation to have a symbol name.");
1612   return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1613 }
1614 
1615 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1616                                   const std::string &name) {
1617   // Op must already be a symbol.
1618   PyOperation &operation = symbol.getOperation();
1619   operation.checkValid();
1620   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1621   MlirAttribute existingNameAttr =
1622       mlirOperationGetAttributeByName(operation.get(), attrName);
1623   if (mlirAttributeIsNull(existingNameAttr))
1624     throw py::value_error("Expected operation to have a symbol name.");
1625   MlirAttribute newNameAttr =
1626       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1627   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1628 }
1629 
1630 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1631   PyOperation &operation = symbol.getOperation();
1632   operation.checkValid();
1633   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1634   MlirAttribute existingVisAttr =
1635       mlirOperationGetAttributeByName(operation.get(), attrName);
1636   if (mlirAttributeIsNull(existingVisAttr))
1637     throw py::value_error("Expected operation to have a symbol visibility.");
1638   return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1639 }
1640 
1641 void PySymbolTable::setVisibility(PyOperationBase &symbol,
1642                                   const std::string &visibility) {
1643   if (visibility != "public" && visibility != "private" &&
1644       visibility != "nested")
1645     throw py::value_error(
1646         "Expected visibility to be 'public', 'private' or 'nested'");
1647   PyOperation &operation = symbol.getOperation();
1648   operation.checkValid();
1649   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1650   MlirAttribute existingVisAttr =
1651       mlirOperationGetAttributeByName(operation.get(), attrName);
1652   if (mlirAttributeIsNull(existingVisAttr))
1653     throw py::value_error("Expected operation to have a symbol visibility.");
1654   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1655                                                toMlirStringRef(visibility));
1656   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1657 }
1658 
1659 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1660                                          const std::string &newSymbol,
1661                                          PyOperationBase &from) {
1662   PyOperation &fromOperation = from.getOperation();
1663   fromOperation.checkValid();
1664   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1665           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1666           from.getOperation())))
1667 
1668     throw py::value_error("Symbol rename failed");
1669 }
1670 
1671 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1672                                      bool allSymUsesVisible,
1673                                      py::object callback) {
1674   PyOperation &fromOperation = from.getOperation();
1675   fromOperation.checkValid();
1676   struct UserData {
1677     PyMlirContextRef context;
1678     py::object callback;
1679     bool gotException;
1680     std::string exceptionWhat;
1681     py::object exceptionType;
1682   };
1683   UserData userData{
1684       fromOperation.getContext(), std::move(callback), false, {}, {}};
1685   mlirSymbolTableWalkSymbolTables(
1686       fromOperation.get(), allSymUsesVisible,
1687       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1688         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1689         auto pyFoundOp =
1690             PyOperation::forOperation(calleeUserData->context, foundOp);
1691         if (calleeUserData->gotException)
1692           return;
1693         try {
1694           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1695         } catch (py::error_already_set &e) {
1696           calleeUserData->gotException = true;
1697           calleeUserData->exceptionWhat = e.what();
1698           calleeUserData->exceptionType = e.type();
1699         }
1700       },
1701       static_cast<void *>(&userData));
1702   if (userData.gotException) {
1703     std::string message("Exception raised in callback: ");
1704     message.append(userData.exceptionWhat);
1705     throw std::runtime_error(message);
1706   }
1707 }
1708 
1709 namespace {
1710 /// CRTP base class for Python MLIR values that subclass Value and should be
1711 /// castable from it. The value hierarchy is one level deep and is not supposed
1712 /// to accommodate other levels unless core MLIR changes.
1713 template <typename DerivedTy>
1714 class PyConcreteValue : public PyValue {
1715 public:
1716   // Derived classes must define statics for:
1717   //   IsAFunctionTy isaFunction
1718   //   const char *pyClassName
1719   // and redefine bindDerived.
1720   using ClassTy = py::class_<DerivedTy, PyValue>;
1721   using IsAFunctionTy = bool (*)(MlirValue);
1722 
1723   PyConcreteValue() = default;
1724   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1725       : PyValue(operationRef, value) {}
1726   PyConcreteValue(PyValue &orig)
1727       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1728 
1729   /// Attempts to cast the original value to the derived type and throws on
1730   /// type mismatches.
1731   static MlirValue castFrom(PyValue &orig) {
1732     if (!DerivedTy::isaFunction(orig.get())) {
1733       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1734       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1735                                              DerivedTy::pyClassName +
1736                                              " (from " + origRepr + ")");
1737     }
1738     return orig.get();
1739   }
1740 
1741   /// Binds the Python module objects to functions of this class.
1742   static void bind(py::module &m) {
1743     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1744     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1745     cls.def_static(
1746         "isinstance",
1747         [](PyValue &otherValue) -> bool {
1748           return DerivedTy::isaFunction(otherValue);
1749         },
1750         py::arg("other_value"));
1751     DerivedTy::bindDerived(cls);
1752   }
1753 
1754   /// Implemented by derived classes to add methods to the Python subclass.
1755   static void bindDerived(ClassTy &m) {}
1756 };
1757 
1758 /// Python wrapper for MlirBlockArgument.
1759 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1760 public:
1761   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1762   static constexpr const char *pyClassName = "BlockArgument";
1763   using PyConcreteValue::PyConcreteValue;
1764 
1765   static void bindDerived(ClassTy &c) {
1766     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1767       return PyBlock(self.getParentOperation(),
1768                      mlirBlockArgumentGetOwner(self.get()));
1769     });
1770     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1771       return mlirBlockArgumentGetArgNumber(self.get());
1772     });
1773     c.def(
1774         "set_type",
1775         [](PyBlockArgument &self, PyType type) {
1776           return mlirBlockArgumentSetType(self.get(), type);
1777         },
1778         py::arg("type"));
1779   }
1780 };
1781 
1782 /// Python wrapper for MlirOpResult.
1783 class PyOpResult : public PyConcreteValue<PyOpResult> {
1784 public:
1785   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1786   static constexpr const char *pyClassName = "OpResult";
1787   using PyConcreteValue::PyConcreteValue;
1788 
1789   static void bindDerived(ClassTy &c) {
1790     c.def_property_readonly("owner", [](PyOpResult &self) {
1791       assert(
1792           mlirOperationEqual(self.getParentOperation()->get(),
1793                              mlirOpResultGetOwner(self.get())) &&
1794           "expected the owner of the value in Python to match that in the IR");
1795       return self.getParentOperation().getObject();
1796     });
1797     c.def_property_readonly("result_number", [](PyOpResult &self) {
1798       return mlirOpResultGetResultNumber(self.get());
1799     });
1800   }
1801 };
1802 
1803 /// Returns the list of types of the values held by container.
1804 template <typename Container>
1805 static std::vector<PyType> getValueTypes(Container &container,
1806                                          PyMlirContextRef &context) {
1807   std::vector<PyType> result;
1808   result.reserve(container.getNumElements());
1809   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1810     result.push_back(
1811         PyType(context, mlirValueGetType(container.getElement(i).get())));
1812   }
1813   return result;
1814 }
1815 
1816 /// A list of block arguments. Internally, these are stored as consecutive
1817 /// elements, random access is cheap. The argument list is associated with the
1818 /// operation that contains the block (detached blocks are not allowed in
1819 /// Python bindings) and extends its lifetime.
1820 class PyBlockArgumentList
1821     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1822 public:
1823   static constexpr const char *pyClassName = "BlockArgumentList";
1824 
1825   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1826                       intptr_t startIndex = 0, intptr_t length = -1,
1827                       intptr_t step = 1)
1828       : Sliceable(startIndex,
1829                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1830                   step),
1831         operation(std::move(operation)), block(block) {}
1832 
1833   /// Returns the number of arguments in the list.
1834   intptr_t getNumElements() {
1835     operation->checkValid();
1836     return mlirBlockGetNumArguments(block);
1837   }
1838 
1839   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1840   PyBlockArgument getElement(intptr_t pos) {
1841     MlirValue argument = mlirBlockGetArgument(block, pos);
1842     return PyBlockArgument(operation, argument);
1843   }
1844 
1845   /// Returns a sublist of this list.
1846   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1847                             intptr_t step) {
1848     return PyBlockArgumentList(operation, block, startIndex, length, step);
1849   }
1850 
1851   static void bindDerived(ClassTy &c) {
1852     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
1853       return getValueTypes(self, self.operation->getContext());
1854     });
1855   }
1856 
1857 private:
1858   PyOperationRef operation;
1859   MlirBlock block;
1860 };
1861 
1862 /// A list of operation operands. Internally, these are stored as consecutive
1863 /// elements, random access is cheap. The result list is associated with the
1864 /// operation whose results these are, and extends the lifetime of this
1865 /// operation.
1866 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1867 public:
1868   static constexpr const char *pyClassName = "OpOperandList";
1869 
1870   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1871                   intptr_t length = -1, intptr_t step = 1)
1872       : Sliceable(startIndex,
1873                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1874                                : length,
1875                   step),
1876         operation(operation) {}
1877 
1878   intptr_t getNumElements() {
1879     operation->checkValid();
1880     return mlirOperationGetNumOperands(operation->get());
1881   }
1882 
1883   PyValue getElement(intptr_t pos) {
1884     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
1885     MlirOperation owner;
1886     if (mlirValueIsAOpResult(operand))
1887       owner = mlirOpResultGetOwner(operand);
1888     else if (mlirValueIsABlockArgument(operand))
1889       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
1890     else
1891       assert(false && "Value must be an block arg or op result.");
1892     PyOperationRef pyOwner =
1893         PyOperation::forOperation(operation->getContext(), owner);
1894     return PyValue(pyOwner, operand);
1895   }
1896 
1897   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1898     return PyOpOperandList(operation, startIndex, length, step);
1899   }
1900 
1901   void dunderSetItem(intptr_t index, PyValue value) {
1902     index = wrapIndex(index);
1903     mlirOperationSetOperand(operation->get(), index, value.get());
1904   }
1905 
1906   static void bindDerived(ClassTy &c) {
1907     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
1908   }
1909 
1910 private:
1911   PyOperationRef operation;
1912 };
1913 
1914 /// A list of operation results. Internally, these are stored as consecutive
1915 /// elements, random access is cheap. The result list is associated with the
1916 /// operation whose results these are, and extends the lifetime of this
1917 /// operation.
1918 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1919 public:
1920   static constexpr const char *pyClassName = "OpResultList";
1921 
1922   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1923                  intptr_t length = -1, intptr_t step = 1)
1924       : Sliceable(startIndex,
1925                   length == -1 ? mlirOperationGetNumResults(operation->get())
1926                                : length,
1927                   step),
1928         operation(operation) {}
1929 
1930   intptr_t getNumElements() {
1931     operation->checkValid();
1932     return mlirOperationGetNumResults(operation->get());
1933   }
1934 
1935   PyOpResult getElement(intptr_t index) {
1936     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1937     return PyOpResult(value);
1938   }
1939 
1940   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1941     return PyOpResultList(operation, startIndex, length, step);
1942   }
1943 
1944   static void bindDerived(ClassTy &c) {
1945     c.def_property_readonly("types", [](PyOpResultList &self) {
1946       return getValueTypes(self, self.operation->getContext());
1947     });
1948   }
1949 
1950 private:
1951   PyOperationRef operation;
1952 };
1953 
1954 /// A list of operation attributes. Can be indexed by name, producing
1955 /// attributes, or by index, producing named attributes.
1956 class PyOpAttributeMap {
1957 public:
1958   PyOpAttributeMap(PyOperationRef operation)
1959       : operation(std::move(operation)) {}
1960 
1961   PyAttribute dunderGetItemNamed(const std::string &name) {
1962     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1963                                                          toMlirStringRef(name));
1964     if (mlirAttributeIsNull(attr)) {
1965       throw SetPyError(PyExc_KeyError,
1966                        "attempt to access a non-existent attribute");
1967     }
1968     return PyAttribute(operation->getContext(), attr);
1969   }
1970 
1971   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1972     if (index < 0 || index >= dunderLen()) {
1973       throw SetPyError(PyExc_IndexError,
1974                        "attempt to access out of bounds attribute");
1975     }
1976     MlirNamedAttribute namedAttr =
1977         mlirOperationGetAttribute(operation->get(), index);
1978     return PyNamedAttribute(
1979         namedAttr.attribute,
1980         std::string(mlirIdentifierStr(namedAttr.name).data,
1981                     mlirIdentifierStr(namedAttr.name).length));
1982   }
1983 
1984   void dunderSetItem(const std::string &name, const PyAttribute &attr) {
1985     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1986                                     attr);
1987   }
1988 
1989   void dunderDelItem(const std::string &name) {
1990     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1991                                                      toMlirStringRef(name));
1992     if (!removed)
1993       throw SetPyError(PyExc_KeyError,
1994                        "attempt to delete a non-existent attribute");
1995   }
1996 
1997   intptr_t dunderLen() {
1998     return mlirOperationGetNumAttributes(operation->get());
1999   }
2000 
2001   bool dunderContains(const std::string &name) {
2002     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2003         operation->get(), toMlirStringRef(name)));
2004   }
2005 
2006   static void bind(py::module &m) {
2007     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2008         .def("__contains__", &PyOpAttributeMap::dunderContains)
2009         .def("__len__", &PyOpAttributeMap::dunderLen)
2010         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2011         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2012         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2013         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2014   }
2015 
2016 private:
2017   PyOperationRef operation;
2018 };
2019 
2020 } // namespace
2021 
2022 //------------------------------------------------------------------------------
2023 // Populates the core exports of the 'ir' submodule.
2024 //------------------------------------------------------------------------------
2025 
2026 void mlir::python::populateIRCore(py::module &m) {
2027   //----------------------------------------------------------------------------
2028   // Mapping of MlirContext.
2029   //----------------------------------------------------------------------------
2030   py::class_<PyMlirContext>(m, "Context", py::module_local())
2031       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2032       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2033       .def("_get_context_again",
2034            [](PyMlirContext &self) {
2035              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2036              return ref.releaseObject();
2037            })
2038       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2039       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2040       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2041                              &PyMlirContext::getCapsule)
2042       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2043       .def("__enter__", &PyMlirContext::contextEnter)
2044       .def("__exit__", &PyMlirContext::contextExit)
2045       .def_property_readonly_static(
2046           "current",
2047           [](py::object & /*class*/) {
2048             auto *context = PyThreadContextEntry::getDefaultContext();
2049             if (!context)
2050               throw SetPyError(PyExc_ValueError, "No current Context");
2051             return context;
2052           },
2053           "Gets the Context bound to the current thread or raises ValueError")
2054       .def_property_readonly(
2055           "dialects",
2056           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2057           "Gets a container for accessing dialects by name")
2058       .def_property_readonly(
2059           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2060           "Alias for 'dialect'")
2061       .def(
2062           "get_dialect_descriptor",
2063           [=](PyMlirContext &self, std::string &name) {
2064             MlirDialect dialect = mlirContextGetOrLoadDialect(
2065                 self.get(), {name.data(), name.size()});
2066             if (mlirDialectIsNull(dialect)) {
2067               throw SetPyError(PyExc_ValueError,
2068                                Twine("Dialect '") + name + "' not found");
2069             }
2070             return PyDialectDescriptor(self.getRef(), dialect);
2071           },
2072           py::arg("dialect_name"),
2073           "Gets or loads a dialect by name, returning its descriptor object")
2074       .def_property(
2075           "allow_unregistered_dialects",
2076           [](PyMlirContext &self) -> bool {
2077             return mlirContextGetAllowUnregisteredDialects(self.get());
2078           },
2079           [](PyMlirContext &self, bool value) {
2080             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2081           })
2082       .def(
2083           "enable_multithreading",
2084           [](PyMlirContext &self, bool enable) {
2085             mlirContextEnableMultithreading(self.get(), enable);
2086           },
2087           py::arg("enable"))
2088       .def(
2089           "is_registered_operation",
2090           [](PyMlirContext &self, std::string &name) {
2091             return mlirContextIsRegisteredOperation(
2092                 self.get(), MlirStringRef{name.data(), name.size()});
2093           },
2094           py::arg("operation_name"));
2095 
2096   //----------------------------------------------------------------------------
2097   // Mapping of PyDialectDescriptor
2098   //----------------------------------------------------------------------------
2099   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2100       .def_property_readonly("namespace",
2101                              [](PyDialectDescriptor &self) {
2102                                MlirStringRef ns =
2103                                    mlirDialectGetNamespace(self.get());
2104                                return py::str(ns.data, ns.length);
2105                              })
2106       .def("__repr__", [](PyDialectDescriptor &self) {
2107         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2108         std::string repr("<DialectDescriptor ");
2109         repr.append(ns.data, ns.length);
2110         repr.append(">");
2111         return repr;
2112       });
2113 
2114   //----------------------------------------------------------------------------
2115   // Mapping of PyDialects
2116   //----------------------------------------------------------------------------
2117   py::class_<PyDialects>(m, "Dialects", py::module_local())
2118       .def("__getitem__",
2119            [=](PyDialects &self, std::string keyName) {
2120              MlirDialect dialect =
2121                  self.getDialectForKey(keyName, /*attrError=*/false);
2122              py::object descriptor =
2123                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2124              return createCustomDialectWrapper(keyName, std::move(descriptor));
2125            })
2126       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2127         MlirDialect dialect =
2128             self.getDialectForKey(attrName, /*attrError=*/true);
2129         py::object descriptor =
2130             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2131         return createCustomDialectWrapper(attrName, std::move(descriptor));
2132       });
2133 
2134   //----------------------------------------------------------------------------
2135   // Mapping of PyDialect
2136   //----------------------------------------------------------------------------
2137   py::class_<PyDialect>(m, "Dialect", py::module_local())
2138       .def(py::init<py::object>(), py::arg("descriptor"))
2139       .def_property_readonly(
2140           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2141       .def("__repr__", [](py::object self) {
2142         auto clazz = self.attr("__class__");
2143         return py::str("<Dialect ") +
2144                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2145                clazz.attr("__module__") + py::str(".") +
2146                clazz.attr("__name__") + py::str(")>");
2147       });
2148 
2149   //----------------------------------------------------------------------------
2150   // Mapping of Location
2151   //----------------------------------------------------------------------------
2152   py::class_<PyLocation>(m, "Location", py::module_local())
2153       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2154       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2155       .def("__enter__", &PyLocation::contextEnter)
2156       .def("__exit__", &PyLocation::contextExit)
2157       .def("__eq__",
2158            [](PyLocation &self, PyLocation &other) -> bool {
2159              return mlirLocationEqual(self, other);
2160            })
2161       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2162       .def_property_readonly_static(
2163           "current",
2164           [](py::object & /*class*/) {
2165             auto *loc = PyThreadContextEntry::getDefaultLocation();
2166             if (!loc)
2167               throw SetPyError(PyExc_ValueError, "No current Location");
2168             return loc;
2169           },
2170           "Gets the Location bound to the current thread or raises ValueError")
2171       .def_static(
2172           "unknown",
2173           [](DefaultingPyMlirContext context) {
2174             return PyLocation(context->getRef(),
2175                               mlirLocationUnknownGet(context->get()));
2176           },
2177           py::arg("context") = py::none(),
2178           "Gets a Location representing an unknown location")
2179       .def_static(
2180           "callsite",
2181           [](PyLocation callee, const std::vector<PyLocation> &frames,
2182              DefaultingPyMlirContext context) {
2183             if (frames.empty())
2184               throw py::value_error("No caller frames provided");
2185             MlirLocation caller = frames.back().get();
2186             for (const PyLocation &frame :
2187                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2188               caller = mlirLocationCallSiteGet(frame.get(), caller);
2189             return PyLocation(context->getRef(),
2190                               mlirLocationCallSiteGet(callee.get(), caller));
2191           },
2192           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2193           kContextGetCallSiteLocationDocstring)
2194       .def_static(
2195           "file",
2196           [](std::string filename, int line, int col,
2197              DefaultingPyMlirContext context) {
2198             return PyLocation(
2199                 context->getRef(),
2200                 mlirLocationFileLineColGet(
2201                     context->get(), toMlirStringRef(filename), line, col));
2202           },
2203           py::arg("filename"), py::arg("line"), py::arg("col"),
2204           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2205       .def_static(
2206           "fused",
2207           [](const std::vector<PyLocation> &pyLocations, llvm::Optional<PyAttribute> metadata,
2208              DefaultingPyMlirContext context) {
2209             if (pyLocations.empty())
2210               throw py::value_error("No locations provided");
2211             llvm::SmallVector<MlirLocation, 4> locations;
2212             locations.reserve(pyLocations.size());
2213             for (auto &pyLocation : pyLocations)
2214               locations.push_back(pyLocation.get());
2215             MlirLocation location = mlirLocationFusedGet(
2216                 context->get(), locations.size(), locations.data(),
2217                 metadata ? metadata->get() : MlirAttribute{0});
2218             return PyLocation(context->getRef(), location);
2219           },
2220           py::arg("locations"), py::arg("metadata") = py::none(),
2221           py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2222       .def_static(
2223           "name",
2224           [](std::string name, llvm::Optional<PyLocation> childLoc,
2225              DefaultingPyMlirContext context) {
2226             return PyLocation(
2227                 context->getRef(),
2228                 mlirLocationNameGet(
2229                     context->get(), toMlirStringRef(name),
2230                     childLoc ? childLoc->get()
2231                              : mlirLocationUnknownGet(context->get())));
2232           },
2233           py::arg("name"), py::arg("childLoc") = py::none(),
2234           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2235       .def_property_readonly(
2236           "context",
2237           [](PyLocation &self) { return self.getContext().getObject(); },
2238           "Context that owns the Location")
2239       .def("__repr__", [](PyLocation &self) {
2240         PyPrintAccumulator printAccum;
2241         mlirLocationPrint(self, printAccum.getCallback(),
2242                           printAccum.getUserData());
2243         return printAccum.join();
2244       });
2245 
2246   //----------------------------------------------------------------------------
2247   // Mapping of Module
2248   //----------------------------------------------------------------------------
2249   py::class_<PyModule>(m, "Module", py::module_local())
2250       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2251       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2252       .def_static(
2253           "parse",
2254           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2255             MlirModule module = mlirModuleCreateParse(
2256                 context->get(), toMlirStringRef(moduleAsm));
2257             // TODO: Rework error reporting once diagnostic engine is exposed
2258             // in C API.
2259             if (mlirModuleIsNull(module)) {
2260               throw SetPyError(
2261                   PyExc_ValueError,
2262                   "Unable to parse module assembly (see diagnostics)");
2263             }
2264             return PyModule::forModule(module).releaseObject();
2265           },
2266           py::arg("asm"), py::arg("context") = py::none(),
2267           kModuleParseDocstring)
2268       .def_static(
2269           "create",
2270           [](DefaultingPyLocation loc) {
2271             MlirModule module = mlirModuleCreateEmpty(loc);
2272             return PyModule::forModule(module).releaseObject();
2273           },
2274           py::arg("loc") = py::none(), "Creates an empty module")
2275       .def_property_readonly(
2276           "context",
2277           [](PyModule &self) { return self.getContext().getObject(); },
2278           "Context that created the Module")
2279       .def_property_readonly(
2280           "operation",
2281           [](PyModule &self) {
2282             return PyOperation::forOperation(self.getContext(),
2283                                              mlirModuleGetOperation(self.get()),
2284                                              self.getRef().releaseObject())
2285                 .releaseObject();
2286           },
2287           "Accesses the module as an operation")
2288       .def_property_readonly(
2289           "body",
2290           [](PyModule &self) {
2291             PyOperationRef moduleOp = PyOperation::forOperation(
2292                 self.getContext(), mlirModuleGetOperation(self.get()),
2293                 self.getRef().releaseObject());
2294             PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2295             return returnBlock;
2296           },
2297           "Return the block for this module")
2298       .def(
2299           "dump",
2300           [](PyModule &self) {
2301             mlirOperationDump(mlirModuleGetOperation(self.get()));
2302           },
2303           kDumpDocstring)
2304       .def(
2305           "__str__",
2306           [](py::object self) {
2307             // Defer to the operation's __str__.
2308             return self.attr("operation").attr("__str__")();
2309           },
2310           kOperationStrDunderDocstring);
2311 
2312   //----------------------------------------------------------------------------
2313   // Mapping of Operation.
2314   //----------------------------------------------------------------------------
  // Python `_OperationBase` class: functionality shared by every wrapper that
  // can resolve to a PyOperation via getOperation().
  py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                             [](PyOperationBase &self) {
                               return self.getOperation().getCapsule();
                             })
      // Equality is identity of the underlying PyOperation object: two
      // wrappers are equal iff getOperation() yields the same address.
      .def("__eq__",
           [](PyOperationBase &self, PyOperationBase &other) {
             return &self.getOperation() == &other.getOperation();
           })
      // Fallback overload: comparison against any other Python object is
      // false. It is registered after the typed overload above.
      .def("__eq__",
           [](PyOperationBase &self, py::object other) { return false; })
      // Hash of the PyOperation address, consistent with __eq__ above.
      .def("__hash__",
           [](PyOperationBase &self) {
             return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
           })
      // The following accessors each wrap a reference to the operation in a
      // dedicated container type.
      .def_property_readonly("attributes",
                             [](PyOperationBase &self) {
                               return PyOpAttributeMap(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("operands",
                             [](PyOperationBase &self) {
                               return PyOpOperandList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("regions",
                             [](PyOperationBase &self) {
                               return PyRegionList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly(
          "results",
          [](PyOperationBase &self) {
            return PyOpResultList(self.getOperation().getRef());
          },
          "Returns the list of Operation results.")
      .def_property_readonly(
          "result",
          [](PyOperationBase &self) {
            auto &operation = self.getOperation();
            auto numResults = mlirOperationGetNumResults(operation);
            // Only operations with exactly one result are accepted; anything
            // else raises a ValueError naming the operation and its result
            // count.
            if (numResults != 1) {
              auto name = mlirIdentifierStr(mlirOperationGetName(operation));
              throw SetPyError(
                  PyExc_ValueError,
                  Twine("Cannot call .result on operation ") +
                      StringRef(name.data, name.length) + " which has " +
                      Twine(numResults) +
                      " results (it is only valid for operations with a "
                      "single result)");
            }
            return PyOpResult(operation.getRef(),
                              mlirOperationGetResult(operation, 0));
          },
          "Shortcut to get an op result if it has only one (throws an error "
          "otherwise).")
      .def_property_readonly(
          "location",
          [](PyOperationBase &self) {
            PyOperation &operation = self.getOperation();
            return PyLocation(operation.getContext(),
                              mlirOperationGetLocation(operation.get()));
          },
          "Returns the source location the operation was defined or derived "
          "from.")
      .def(
          "__str__",
          [](PyOperationBase &self) {
            // Same as get_asm() with every option at the default value
            // declared for it below.
            return self.getAsm(/*binary=*/false,
                               /*largeElementsLimit=*/llvm::None,
                               /*enableDebugInfo=*/false,
                               /*prettyDebugInfo=*/false,
                               /*printGenericOpForm=*/false,
                               /*useLocalScope=*/false,
                               /*assumeVerified=*/false);
          },
          "Returns the assembly form of the operation.")
      .def("print", &PyOperationBase::print,
           // Careful: Lots of arguments must match up with print method.
           py::arg("file") = py::none(), py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false,
           py::arg("assume_verified") = false, kOperationPrintDocstring)
      .def("get_asm", &PyOperationBase::getAsm,
           // Careful: Lots of arguments must match up with get_asm method.
           py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false,
           py::arg("assume_verified") = false, kOperationGetAsmDocstring)
      .def(
          "verify",
          [](PyOperationBase &self) {
            return mlirOperationVerify(self.getOperation());
          },
          "Verify the operation and return true if it passes, false if it "
          "fails.")
      .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
           "Puts self immediately after the other operation in its parent "
           "block.")
      .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
           "Puts self immediately before the other operation in its parent "
           "block.")
      .def(
          "detach_from_parent",
          [](PyOperationBase &self) {
            PyOperation &operation = self.getOperation();
            // Validate first, then reject operations that are not attached
            // with a ValueError.
            operation.checkValid();
            if (!operation.isAttached())
              throw py::value_error("Detached operation has no parent.");

            // Detach and return an OpView for the now-detached operation.
            operation.detachFromParent();
            return operation.createOpView();
          },
          "Detaches the operation from its parent block.");
2435 
2436   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2437       .def_static("create", &PyOperation::create, py::arg("name"),
2438                   py::arg("results") = py::none(),
2439                   py::arg("operands") = py::none(),
2440                   py::arg("attributes") = py::none(),
2441                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2442                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2443                   kOperationCreateDocstring)
2444       .def_property_readonly("parent",
2445                              [](PyOperation &self) -> py::object {
2446                                auto parent = self.getParentOperation();
2447                                if (parent)
2448                                  return parent->getObject();
2449                                return py::none();
2450                              })
2451       .def("erase", &PyOperation::erase)
2452       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2453                              &PyOperation::getCapsule)
2454       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2455       .def_property_readonly("name",
2456                              [](PyOperation &self) {
2457                                self.checkValid();
2458                                MlirOperation operation = self.get();
2459                                MlirStringRef name = mlirIdentifierStr(
2460                                    mlirOperationGetName(operation));
2461                                return py::str(name.data, name.length);
2462                              })
2463       .def_property_readonly(
2464           "context",
2465           [](PyOperation &self) {
2466             self.checkValid();
2467             return self.getContext().getObject();
2468           },
2469           "Context that owns the Operation")
2470       .def_property_readonly("opview", &PyOperation::createOpView);
2471 
2472   auto opViewClass =
2473       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2474           .def(py::init<py::object>(), py::arg("operation"))
2475           .def_property_readonly("operation", &PyOpView::getOperationObject)
2476           .def_property_readonly(
2477               "context",
2478               [](PyOpView &self) {
2479                 return self.getOperation().getContext().getObject();
2480               },
2481               "Context that owns the Operation")
2482           .def("__str__", [](PyOpView &self) {
2483             return py::str(self.getOperationObject());
2484           });
2485   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2486   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2487   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2488   opViewClass.attr("build_generic") = classmethod(
2489       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2490       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2491       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2492       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2493       "Builds a specific, generated OpView based on class level attributes.");
2494 
2495   //----------------------------------------------------------------------------
2496   // Mapping of PyRegion.
2497   //----------------------------------------------------------------------------
2498   py::class_<PyRegion>(m, "Region", py::module_local())
2499       .def_property_readonly(
2500           "blocks",
2501           [](PyRegion &self) {
2502             return PyBlockList(self.getParentOperation(), self.get());
2503           },
2504           "Returns a forward-optimized sequence of blocks.")
2505       .def_property_readonly(
2506           "owner",
2507           [](PyRegion &self) {
2508             return self.getParentOperation()->createOpView();
2509           },
2510           "Returns the operation owning this region.")
2511       .def(
2512           "__iter__",
2513           [](PyRegion &self) {
2514             self.checkValid();
2515             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2516             return PyBlockIterator(self.getParentOperation(), firstBlock);
2517           },
2518           "Iterates over blocks in the region.")
2519       .def("__eq__",
2520            [](PyRegion &self, PyRegion &other) {
2521              return self.get().ptr == other.get().ptr;
2522            })
2523       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2524 
2525   //----------------------------------------------------------------------------
2526   // Mapping of PyBlock.
2527   //----------------------------------------------------------------------------
2528   py::class_<PyBlock>(m, "Block", py::module_local())
2529       .def_property_readonly(
2530           "owner",
2531           [](PyBlock &self) {
2532             return self.getParentOperation()->createOpView();
2533           },
2534           "Returns the owning operation of this block.")
2535       .def_property_readonly(
2536           "region",
2537           [](PyBlock &self) {
2538             MlirRegion region = mlirBlockGetParentRegion(self.get());
2539             return PyRegion(self.getParentOperation(), region);
2540           },
2541           "Returns the owning region of this block.")
2542       .def_property_readonly(
2543           "arguments",
2544           [](PyBlock &self) {
2545             return PyBlockArgumentList(self.getParentOperation(), self.get());
2546           },
2547           "Returns a list of block arguments.")
2548       .def_property_readonly(
2549           "operations",
2550           [](PyBlock &self) {
2551             return PyOperationList(self.getParentOperation(), self.get());
2552           },
2553           "Returns a forward-optimized sequence of operations.")
2554       .def_static(
2555           "create_at_start",
2556           [](PyRegion &parent, py::list pyArgTypes) {
2557             parent.checkValid();
2558             llvm::SmallVector<MlirType, 4> argTypes;
2559             argTypes.reserve(pyArgTypes.size());
2560             for (auto &pyArg : pyArgTypes) {
2561               argTypes.push_back(pyArg.cast<PyType &>());
2562             }
2563 
2564             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2565             mlirRegionInsertOwnedBlock(parent, 0, block);
2566             return PyBlock(parent.getParentOperation(), block);
2567           },
2568           py::arg("parent"), py::arg("arg_types") = py::list(),
2569           "Creates and returns a new Block at the beginning of the given "
2570           "region (with given argument types).")
2571       .def(
2572           "create_before",
2573           [](PyBlock &self, py::args pyArgTypes) {
2574             self.checkValid();
2575             llvm::SmallVector<MlirType, 4> argTypes;
2576             argTypes.reserve(pyArgTypes.size());
2577             for (auto &pyArg : pyArgTypes) {
2578               argTypes.push_back(pyArg.cast<PyType &>());
2579             }
2580 
2581             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2582             MlirRegion region = mlirBlockGetParentRegion(self.get());
2583             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2584             return PyBlock(self.getParentOperation(), block);
2585           },
2586           "Creates and returns a new Block before this block "
2587           "(with given argument types).")
2588       .def(
2589           "create_after",
2590           [](PyBlock &self, py::args pyArgTypes) {
2591             self.checkValid();
2592             llvm::SmallVector<MlirType, 4> argTypes;
2593             argTypes.reserve(pyArgTypes.size());
2594             for (auto &pyArg : pyArgTypes) {
2595               argTypes.push_back(pyArg.cast<PyType &>());
2596             }
2597 
2598             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
2599             MlirRegion region = mlirBlockGetParentRegion(self.get());
2600             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2601             return PyBlock(self.getParentOperation(), block);
2602           },
2603           "Creates and returns a new Block after this block "
2604           "(with given argument types).")
2605       .def(
2606           "__iter__",
2607           [](PyBlock &self) {
2608             self.checkValid();
2609             MlirOperation firstOperation =
2610                 mlirBlockGetFirstOperation(self.get());
2611             return PyOperationIterator(self.getParentOperation(),
2612                                        firstOperation);
2613           },
2614           "Iterates over operations in the block.")
2615       .def("__eq__",
2616            [](PyBlock &self, PyBlock &other) {
2617              return self.get().ptr == other.get().ptr;
2618            })
2619       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2620       .def(
2621           "__str__",
2622           [](PyBlock &self) {
2623             self.checkValid();
2624             PyPrintAccumulator printAccum;
2625             mlirBlockPrint(self.get(), printAccum.getCallback(),
2626                            printAccum.getUserData());
2627             return printAccum.join();
2628           },
2629           "Returns the assembly form of the block.")
2630       .def(
2631           "append",
2632           [](PyBlock &self, PyOperationBase &operation) {
2633             if (operation.getOperation().isAttached())
2634               operation.getOperation().detachFromParent();
2635 
2636             MlirOperation mlirOperation = operation.getOperation().get();
2637             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2638             operation.getOperation().setAttached(
2639                 self.getParentOperation().getObject());
2640           },
2641           py::arg("operation"),
2642           "Appends an operation to this block. If the operation is currently "
2643           "in another block, it will be moved.");
2644 
2645   //----------------------------------------------------------------------------
2646   // Mapping of PyInsertionPoint.
2647   //----------------------------------------------------------------------------
2648 
2649   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2650       .def(py::init<PyBlock &>(), py::arg("block"),
2651            "Inserts after the last operation but still inside the block.")
2652       .def("__enter__", &PyInsertionPoint::contextEnter)
2653       .def("__exit__", &PyInsertionPoint::contextExit)
2654       .def_property_readonly_static(
2655           "current",
2656           [](py::object & /*class*/) {
2657             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2658             if (!ip)
2659               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2660             return ip;
2661           },
2662           "Gets the InsertionPoint bound to the current thread or raises "
2663           "ValueError if none has been set")
2664       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2665            "Inserts before a referenced operation.")
2666       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2667                   py::arg("block"), "Inserts at the beginning of the block.")
2668       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2669                   py::arg("block"), "Inserts before the block terminator.")
2670       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2671            "Inserts an operation.")
2672       .def_property_readonly(
2673           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2674           "Returns the block that this InsertionPoint points to.");
2675 
2676   //----------------------------------------------------------------------------
2677   // Mapping of PyAttribute.
2678   //----------------------------------------------------------------------------
2679   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2680       // Delegate to the PyAttribute copy constructor, which will also lifetime
2681       // extend the backing context which owns the MlirAttribute.
2682       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2683            "Casts the passed attribute to the generic Attribute")
2684       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2685                              &PyAttribute::getCapsule)
2686       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2687       .def_static(
2688           "parse",
2689           [](std::string attrSpec, DefaultingPyMlirContext context) {
2690             MlirAttribute type = mlirAttributeParseGet(
2691                 context->get(), toMlirStringRef(attrSpec));
2692             // TODO: Rework error reporting once diagnostic engine is exposed
2693             // in C API.
2694             if (mlirAttributeIsNull(type)) {
2695               throw SetPyError(PyExc_ValueError,
2696                                Twine("Unable to parse attribute: '") +
2697                                    attrSpec + "'");
2698             }
2699             return PyAttribute(context->getRef(), type);
2700           },
2701           py::arg("asm"), py::arg("context") = py::none(),
2702           "Parses an attribute from an assembly form")
2703       .def_property_readonly(
2704           "context",
2705           [](PyAttribute &self) { return self.getContext().getObject(); },
2706           "Context that owns the Attribute")
2707       .def_property_readonly("type",
2708                              [](PyAttribute &self) {
2709                                return PyType(self.getContext()->getRef(),
2710                                              mlirAttributeGetType(self));
2711                              })
2712       .def(
2713           "get_named",
2714           [](PyAttribute &self, std::string name) {
2715             return PyNamedAttribute(self, std::move(name));
2716           },
2717           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2718       .def("__eq__",
2719            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2720       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2721       .def("__hash__",
2722            [](PyAttribute &self) {
2723              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2724            })
2725       .def(
2726           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2727           kDumpDocstring)
2728       .def(
2729           "__str__",
2730           [](PyAttribute &self) {
2731             PyPrintAccumulator printAccum;
2732             mlirAttributePrint(self, printAccum.getCallback(),
2733                                printAccum.getUserData());
2734             return printAccum.join();
2735           },
2736           "Returns the assembly form of the Attribute.")
2737       .def("__repr__", [](PyAttribute &self) {
2738         // Generally, assembly formats are not printed for __repr__ because
2739         // this can cause exceptionally long debug output and exceptions.
2740         // However, attribute values are generally considered useful and are
2741         // printed. This may need to be re-evaluated if debug dumps end up
2742         // being excessive.
2743         PyPrintAccumulator printAccum;
2744         printAccum.parts.append("Attribute(");
2745         mlirAttributePrint(self, printAccum.getCallback(),
2746                            printAccum.getUserData());
2747         printAccum.parts.append(")");
2748         return printAccum.join();
2749       });
2750 
2751   //----------------------------------------------------------------------------
2752   // Mapping of PyNamedAttribute
2753   //----------------------------------------------------------------------------
2754   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2755       .def("__repr__",
2756            [](PyNamedAttribute &self) {
2757              PyPrintAccumulator printAccum;
2758              printAccum.parts.append("NamedAttribute(");
2759              printAccum.parts.append(
2760                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
2761                          mlirIdentifierStr(self.namedAttr.name).length));
2762              printAccum.parts.append("=");
2763              mlirAttributePrint(self.namedAttr.attribute,
2764                                 printAccum.getCallback(),
2765                                 printAccum.getUserData());
2766              printAccum.parts.append(")");
2767              return printAccum.join();
2768            })
2769       .def_property_readonly(
2770           "name",
2771           [](PyNamedAttribute &self) {
2772             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2773                            mlirIdentifierStr(self.namedAttr.name).length);
2774           },
2775           "The name of the NamedAttribute binding")
2776       .def_property_readonly(
2777           "attr",
2778           [](PyNamedAttribute &self) {
2779             // TODO: When named attribute is removed/refactored, also remove
2780             // this constructor (it does an inefficient table lookup).
2781             auto contextRef = PyMlirContext::forContext(
2782                 mlirAttributeGetContext(self.namedAttr.attribute));
2783             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2784           },
2785           py::keep_alive<0, 1>(),
2786           "The underlying generic attribute of the NamedAttribute binding");
2787 
2788   //----------------------------------------------------------------------------
2789   // Mapping of PyType.
2790   //----------------------------------------------------------------------------
2791   py::class_<PyType>(m, "Type", py::module_local())
2792       // Delegate to the PyType copy constructor, which will also lifetime
2793       // extend the backing context which owns the MlirType.
2794       .def(py::init<PyType &>(), py::arg("cast_from_type"),
2795            "Casts the passed type to the generic Type")
2796       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2797       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2798       .def_static(
2799           "parse",
2800           [](std::string typeSpec, DefaultingPyMlirContext context) {
2801             MlirType type =
2802                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2803             // TODO: Rework error reporting once diagnostic engine is exposed
2804             // in C API.
2805             if (mlirTypeIsNull(type)) {
2806               throw SetPyError(PyExc_ValueError,
2807                                Twine("Unable to parse type: '") + typeSpec +
2808                                    "'");
2809             }
2810             return PyType(context->getRef(), type);
2811           },
2812           py::arg("asm"), py::arg("context") = py::none(),
2813           kContextParseTypeDocstring)
2814       .def_property_readonly(
2815           "context", [](PyType &self) { return self.getContext().getObject(); },
2816           "Context that owns the Type")
2817       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2818       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2819       .def("__hash__",
2820            [](PyType &self) {
2821              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2822            })
2823       .def(
2824           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2825       .def(
2826           "__str__",
2827           [](PyType &self) {
2828             PyPrintAccumulator printAccum;
2829             mlirTypePrint(self, printAccum.getCallback(),
2830                           printAccum.getUserData());
2831             return printAccum.join();
2832           },
2833           "Returns the assembly form of the type.")
2834       .def("__repr__", [](PyType &self) {
2835         // Generally, assembly formats are not printed for __repr__ because
2836         // this can cause exceptionally long debug output and exceptions.
2837         // However, types are an exception as they typically have compact
2838         // assembly forms and printing them is useful.
2839         PyPrintAccumulator printAccum;
2840         printAccum.parts.append("Type(");
2841         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2842         printAccum.parts.append(")");
2843         return printAccum.join();
2844       });
2845 
2846   //----------------------------------------------------------------------------
2847   // Mapping of Value.
2848   //----------------------------------------------------------------------------
2849   py::class_<PyValue>(m, "Value", py::module_local())
2850       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
2851       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
2852       .def_property_readonly(
2853           "context",
2854           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2855           "Context in which the value lives.")
2856       .def(
2857           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2858           kDumpDocstring)
2859       .def_property_readonly(
2860           "owner",
2861           [](PyValue &self) {
2862             assert(mlirOperationEqual(self.getParentOperation()->get(),
2863                                       mlirOpResultGetOwner(self.get())) &&
2864                    "expected the owner of the value in Python to match that in "
2865                    "the IR");
2866             return self.getParentOperation().getObject();
2867           })
2868       .def("__eq__",
2869            [](PyValue &self, PyValue &other) {
2870              return self.get().ptr == other.get().ptr;
2871            })
2872       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2873       .def("__hash__",
2874            [](PyValue &self) {
2875              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2876            })
2877       .def(
2878           "__str__",
2879           [](PyValue &self) {
2880             PyPrintAccumulator printAccum;
2881             printAccum.parts.append("Value(");
2882             mlirValuePrint(self.get(), printAccum.getCallback(),
2883                            printAccum.getUserData());
2884             printAccum.parts.append(")");
2885             return printAccum.join();
2886           },
2887           kValueDunderStrDocstring)
2888       .def_property_readonly("type", [](PyValue &self) {
2889         return PyType(self.getParentOperation()->getContext(),
2890                       mlirValueGetType(self.get()));
2891       });
2892   PyBlockArgument::bind(m);
2893   PyOpResult::bind(m);
2894 
2895   //----------------------------------------------------------------------------
2896   // Mapping of SymbolTable.
2897   //----------------------------------------------------------------------------
2898   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
2899       .def(py::init<PyOperationBase &>())
2900       .def("__getitem__", &PySymbolTable::dunderGetItem)
2901       .def("insert", &PySymbolTable::insert, py::arg("operation"))
2902       .def("erase", &PySymbolTable::erase, py::arg("operation"))
2903       .def("__delitem__", &PySymbolTable::dunderDel)
2904       .def("__contains__",
2905            [](PySymbolTable &table, const std::string &name) {
2906              return !mlirOperationIsNull(mlirSymbolTableLookup(
2907                  table, mlirStringRefCreate(name.data(), name.length())));
2908            })
2909       // Static helpers.
2910       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
2911                   py::arg("symbol"), py::arg("name"))
2912       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
2913                   py::arg("symbol"))
2914       .def_static("get_visibility", &PySymbolTable::getVisibility,
2915                   py::arg("symbol"))
2916       .def_static("set_visibility", &PySymbolTable::setVisibility,
2917                   py::arg("symbol"), py::arg("visibility"))
2918       .def_static("replace_all_symbol_uses",
2919                   &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
2920                   py::arg("new_symbol"), py::arg("from_op"))
2921       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
2922                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
2923                   py::arg("callback"));
2924 
  // Container bindings.
  // Each call hands the module `m` to the static `bind` entry point of one
  // container/iterator wrapper. The `bind` definitions are outside this chunk;
  // presumably each registers its Python class on `m` (confirm at the
  // definition). The call order is left as is.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);
2936 
  // Debug bindings.
  // Hands `m` to PyGlobalDebugFlag's static `bind` entry point (defined
  // outside this chunk; presumably exposes the global debug flag to Python —
  // confirm at the definition).
  PyGlobalDebugFlag::bind(m);
2939 }
2940