1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Debug.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/Registration.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include <pybind11/stl.h>
23 
24 #include <utility>
25 
26 namespace py = pybind11;
27 using namespace mlir;
28 using namespace mlir::python;
29 
30 using llvm::SmallVector;
31 using llvm::StringRef;
32 using llvm::Twine;
33 
34 //------------------------------------------------------------------------------
35 // Docstrings (trivial, non-duplicated docstrings are included inline).
36 //------------------------------------------------------------------------------
37 
// The constants below are passed as docstrings to the pybind11 bindings, so
// their text is user-visible from Python (e.g. via help()).

static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

static const char kContextGetCallSiteLocationDocstring[] =
    R"(Gets a Location representing a caller and callsite)";

static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

static const char kContextGetFusedLocationDocstring[] =
    R"(Gets a Location representing a fused location with optional metadata)";

static const char kContextGetNameLocationDocString[] =
    R"(Gets a Location representing a named location with optional child location)";

static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";

// The keyword names listed here should match the py::arg names used when
// binding Operation.print / get_asm, since users copy them verbatim.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file-like object.

Args:
  file: The file-like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
  assume_verified: By default, if not printing generic form, the verifier
    will be run and if it fails, generic form will be printed with a comment
    about failed verification. While a reasonable default for interactive use,
    for systematic use, it is often better for the caller to verify explicitly
    and report failures in a more robust fashion. Set this to True if doing this
    in order to avoid running a redundant verification. If the IR is actually
    invalid, behavior is undefined.
)";

static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
148 
149 //------------------------------------------------------------------------------
150 // Utilities.
151 //------------------------------------------------------------------------------
152 
153 /// Helper for creating an @classmethod.
154 template <class Func, typename... Args>
155 py::object classmethod(Func f, Args... args) {
156   py::object cf = py::cpp_function(f, args...);
157   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
158 }
159 
160 static py::object
161 createCustomDialectWrapper(const std::string &dialectNamespace,
162                            py::object dialectDescriptor) {
163   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
164   if (!dialectClass) {
165     // Use the base class.
166     return py::cast(PyDialect(std::move(dialectDescriptor)));
167   }
168 
169   // Create the custom implementation.
170   return (*dialectClass)(std::move(dialectDescriptor));
171 }
172 
173 static MlirStringRef toMlirStringRef(const std::string &s) {
174   return mlirStringRefCreate(s.data(), s.size());
175 }
176 
177 /// Wrapper for the global LLVM debugging flag.
178 struct PyGlobalDebugFlag {
179   static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); }
180 
181   static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); }
182 
183   static void bind(py::module &m) {
184     // Debug flags.
185     py::class_<PyGlobalDebugFlag>(m, "_GlobalDebug", py::module_local())
186         .def_property_static("flag", &PyGlobalDebugFlag::get,
187                              &PyGlobalDebugFlag::set, "LLVM-wide debug flag");
188   }
189 };
190 
191 //------------------------------------------------------------------------------
192 // Collections.
193 //------------------------------------------------------------------------------
194 
195 namespace {
196 
197 class PyRegionIterator {
198 public:
199   PyRegionIterator(PyOperationRef operation)
200       : operation(std::move(operation)) {}
201 
202   PyRegionIterator &dunderIter() { return *this; }
203 
204   PyRegion dunderNext() {
205     operation->checkValid();
206     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
207       throw py::stop_iteration();
208     }
209     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
210     return PyRegion(operation, region);
211   }
212 
213   static void bind(py::module &m) {
214     py::class_<PyRegionIterator>(m, "RegionIterator", py::module_local())
215         .def("__iter__", &PyRegionIterator::dunderIter)
216         .def("__next__", &PyRegionIterator::dunderNext);
217   }
218 
219 private:
220   PyOperationRef operation;
221   int nextIndex = 0;
222 };
223 
224 /// Regions of an op are fixed length and indexed numerically so are represented
225 /// with a sequence-like container.
226 class PyRegionList {
227 public:
228   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
229 
230   intptr_t dunderLen() {
231     operation->checkValid();
232     return mlirOperationGetNumRegions(operation->get());
233   }
234 
235   PyRegion dunderGetItem(intptr_t index) {
236     // dunderLen checks validity.
237     if (index < 0 || index >= dunderLen()) {
238       throw SetPyError(PyExc_IndexError,
239                        "attempt to access out of bounds region");
240     }
241     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
242     return PyRegion(operation, region);
243   }
244 
245   static void bind(py::module &m) {
246     py::class_<PyRegionList>(m, "RegionSequence", py::module_local())
247         .def("__len__", &PyRegionList::dunderLen)
248         .def("__getitem__", &PyRegionList::dunderGetItem);
249   }
250 
251 private:
252   PyOperationRef operation;
253 };
254 
255 class PyBlockIterator {
256 public:
257   PyBlockIterator(PyOperationRef operation, MlirBlock next)
258       : operation(std::move(operation)), next(next) {}
259 
260   PyBlockIterator &dunderIter() { return *this; }
261 
262   PyBlock dunderNext() {
263     operation->checkValid();
264     if (mlirBlockIsNull(next)) {
265       throw py::stop_iteration();
266     }
267 
268     PyBlock returnBlock(operation, next);
269     next = mlirBlockGetNextInRegion(next);
270     return returnBlock;
271   }
272 
273   static void bind(py::module &m) {
274     py::class_<PyBlockIterator>(m, "BlockIterator", py::module_local())
275         .def("__iter__", &PyBlockIterator::dunderIter)
276         .def("__next__", &PyBlockIterator::dunderNext);
277   }
278 
279 private:
280   PyOperationRef operation;
281   MlirBlock next;
282 };
283 
284 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
285 /// we present them as a more full-featured list-like container but optimize
286 /// it for forward iteration. Blocks are always owned by a region.
287 class PyBlockList {
288 public:
289   PyBlockList(PyOperationRef operation, MlirRegion region)
290       : operation(std::move(operation)), region(region) {}
291 
292   PyBlockIterator dunderIter() {
293     operation->checkValid();
294     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
295   }
296 
297   intptr_t dunderLen() {
298     operation->checkValid();
299     intptr_t count = 0;
300     MlirBlock block = mlirRegionGetFirstBlock(region);
301     while (!mlirBlockIsNull(block)) {
302       count += 1;
303       block = mlirBlockGetNextInRegion(block);
304     }
305     return count;
306   }
307 
308   PyBlock dunderGetItem(intptr_t index) {
309     operation->checkValid();
310     if (index < 0) {
311       throw SetPyError(PyExc_IndexError,
312                        "attempt to access out of bounds block");
313     }
314     MlirBlock block = mlirRegionGetFirstBlock(region);
315     while (!mlirBlockIsNull(block)) {
316       if (index == 0) {
317         return PyBlock(operation, block);
318       }
319       block = mlirBlockGetNextInRegion(block);
320       index -= 1;
321     }
322     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
323   }
324 
325   PyBlock appendBlock(const py::args &pyArgTypes) {
326     operation->checkValid();
327     llvm::SmallVector<MlirType, 4> argTypes;
328     llvm::SmallVector<MlirLocation, 4> argLocs;
329     argTypes.reserve(pyArgTypes.size());
330     argLocs.reserve(pyArgTypes.size());
331     for (auto &pyArg : pyArgTypes) {
332       argTypes.push_back(pyArg.cast<PyType &>());
333       // TODO: Pass in a proper location here.
334       argLocs.push_back(
335           mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
336     }
337 
338     MlirBlock block =
339         mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data());
340     mlirRegionAppendOwnedBlock(region, block);
341     return PyBlock(operation, block);
342   }
343 
344   static void bind(py::module &m) {
345     py::class_<PyBlockList>(m, "BlockList", py::module_local())
346         .def("__getitem__", &PyBlockList::dunderGetItem)
347         .def("__iter__", &PyBlockList::dunderIter)
348         .def("__len__", &PyBlockList::dunderLen)
349         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
350   }
351 
352 private:
353   PyOperationRef operation;
354   MlirRegion region;
355 };
356 
357 class PyOperationIterator {
358 public:
359   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
360       : parentOperation(std::move(parentOperation)), next(next) {}
361 
362   PyOperationIterator &dunderIter() { return *this; }
363 
364   py::object dunderNext() {
365     parentOperation->checkValid();
366     if (mlirOperationIsNull(next)) {
367       throw py::stop_iteration();
368     }
369 
370     PyOperationRef returnOperation =
371         PyOperation::forOperation(parentOperation->getContext(), next);
372     next = mlirOperationGetNextInBlock(next);
373     return returnOperation->createOpView();
374   }
375 
376   static void bind(py::module &m) {
377     py::class_<PyOperationIterator>(m, "OperationIterator", py::module_local())
378         .def("__iter__", &PyOperationIterator::dunderIter)
379         .def("__next__", &PyOperationIterator::dunderNext);
380   }
381 
382 private:
383   PyOperationRef parentOperation;
384   MlirOperation next;
385 };
386 
387 /// Operations are exposed by the C-API as a forward-only linked list. In
388 /// Python, we present them as a more full-featured list-like container but
389 /// optimize it for forward iteration. Iterable operations are always owned
390 /// by a block.
391 class PyOperationList {
392 public:
393   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
394       : parentOperation(std::move(parentOperation)), block(block) {}
395 
396   PyOperationIterator dunderIter() {
397     parentOperation->checkValid();
398     return PyOperationIterator(parentOperation,
399                                mlirBlockGetFirstOperation(block));
400   }
401 
402   intptr_t dunderLen() {
403     parentOperation->checkValid();
404     intptr_t count = 0;
405     MlirOperation childOp = mlirBlockGetFirstOperation(block);
406     while (!mlirOperationIsNull(childOp)) {
407       count += 1;
408       childOp = mlirOperationGetNextInBlock(childOp);
409     }
410     return count;
411   }
412 
413   py::object dunderGetItem(intptr_t index) {
414     parentOperation->checkValid();
415     if (index < 0) {
416       throw SetPyError(PyExc_IndexError,
417                        "attempt to access out of bounds operation");
418     }
419     MlirOperation childOp = mlirBlockGetFirstOperation(block);
420     while (!mlirOperationIsNull(childOp)) {
421       if (index == 0) {
422         return PyOperation::forOperation(parentOperation->getContext(), childOp)
423             ->createOpView();
424       }
425       childOp = mlirOperationGetNextInBlock(childOp);
426       index -= 1;
427     }
428     throw SetPyError(PyExc_IndexError,
429                      "attempt to access out of bounds operation");
430   }
431 
432   static void bind(py::module &m) {
433     py::class_<PyOperationList>(m, "OperationList", py::module_local())
434         .def("__getitem__", &PyOperationList::dunderGetItem)
435         .def("__iter__", &PyOperationList::dunderIter)
436         .def("__len__", &PyOperationList::dunderLen);
437   }
438 
439 private:
440   PyOperationRef parentOperation;
441   MlirBlock block;
442 };
443 
444 } // namespace
445 
446 //------------------------------------------------------------------------------
447 // PyMlirContext
448 //------------------------------------------------------------------------------
449 
450 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
451   py::gil_scoped_acquire acquire;
452   auto &liveContexts = getLiveContexts();
453   liveContexts[context.ptr] = this;
454 }
455 
456 PyMlirContext::~PyMlirContext() {
457   // Note that the only public way to construct an instance is via the
458   // forContext method, which always puts the associated handle into
459   // liveContexts.
460   py::gil_scoped_acquire acquire;
461   getLiveContexts().erase(context.ptr);
462   mlirContextDestroy(context);
463 }
464 
465 py::object PyMlirContext::getCapsule() {
466   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
467 }
468 
469 py::object PyMlirContext::createFromCapsule(py::object capsule) {
470   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
471   if (mlirContextIsNull(rawContext))
472     throw py::error_already_set();
473   return forContext(rawContext).releaseObject();
474 }
475 
476 PyMlirContext *PyMlirContext::createNewContextForInit() {
477   MlirContext context = mlirContextCreate();
478   mlirRegisterAllDialects(context);
479   return new PyMlirContext(context);
480 }
481 
482 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
483   py::gil_scoped_acquire acquire;
484   auto &liveContexts = getLiveContexts();
485   auto it = liveContexts.find(context.ptr);
486   if (it == liveContexts.end()) {
487     // Create.
488     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
489     py::object pyRef = py::cast(unownedContextWrapper);
490     assert(pyRef && "cast to py::object failed");
491     liveContexts[context.ptr] = unownedContextWrapper;
492     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
493   }
494   // Use existing.
495   py::object pyRef = py::cast(it->second);
496   return PyMlirContextRef(it->second, std::move(pyRef));
497 }
498 
499 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
500   static LiveContextMap liveContexts;
501   return liveContexts;
502 }
503 
504 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
505 
506 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
507 
508 size_t PyMlirContext::clearLiveOperations() {
509   for (auto &op : liveOperations)
510     op.second.second->setInvalid();
511   size_t numInvalidated = liveOperations.size();
512   liveOperations.clear();
513   return numInvalidated;
514 }
515 
516 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
517 
518 pybind11::object PyMlirContext::contextEnter() {
519   return PyThreadContextEntry::pushContext(*this);
520 }
521 
522 void PyMlirContext::contextExit(const pybind11::object &excType,
523                                 const pybind11::object &excVal,
524                                 const pybind11::object &excTb) {
525   PyThreadContextEntry::popContext(*this);
526 }
527 
528 py::object PyMlirContext::attachDiagnosticHandler(py::object callback) {
529   // Note that ownership is transferred to the delete callback below by way of
530   // an explicit inc_ref (borrow).
531   PyDiagnosticHandler *pyHandler =
532       new PyDiagnosticHandler(get(), std::move(callback));
533   py::object pyHandlerObject =
534       py::cast(pyHandler, py::return_value_policy::take_ownership);
535   pyHandlerObject.inc_ref();
536 
537   // In these C callbacks, the userData is a PyDiagnosticHandler* that is
538   // guaranteed to be known to pybind.
539   auto handlerCallback =
540       +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult {
541     PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic);
542     py::object pyDiagnosticObject =
543         py::cast(pyDiagnostic, py::return_value_policy::take_ownership);
544 
545     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
546     bool result = false;
547     {
548       // Since this can be called from arbitrary C++ contexts, always get the
549       // gil.
550       py::gil_scoped_acquire gil;
551       try {
552         result = py::cast<bool>(pyHandler->callback(pyDiagnostic));
553       } catch (std::exception &e) {
554         fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n",
555                 e.what());
556         pyHandler->hadError = true;
557       }
558     }
559 
560     pyDiagnostic->invalidate();
561     return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure();
562   };
563   auto deleteCallback = +[](void *userData) {
564     auto *pyHandler = static_cast<PyDiagnosticHandler *>(userData);
565     assert(pyHandler->registeredID && "handler is not registered");
566     pyHandler->registeredID.reset();
567 
568     // Decrement reference, balancing the inc_ref() above.
569     py::object pyHandlerObject =
570         py::cast(pyHandler, py::return_value_policy::reference);
571     pyHandlerObject.dec_ref();
572   };
573 
574   pyHandler->registeredID = mlirContextAttachDiagnosticHandler(
575       get(), handlerCallback, static_cast<void *>(pyHandler), deleteCallback);
576   return pyHandlerObject;
577 }
578 
579 PyMlirContext &DefaultingPyMlirContext::resolve() {
580   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
581   if (!context) {
582     throw SetPyError(
583         PyExc_RuntimeError,
584         "An MLIR function requires a Context but none was provided in the call "
585         "or from the surrounding environment. Either pass to the function with "
586         "a 'context=' argument or establish a default using 'with Context():'");
587   }
588   return *context;
589 }
590 
591 //------------------------------------------------------------------------------
592 // PyThreadContextEntry management
593 //------------------------------------------------------------------------------
594 
595 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
596   static thread_local std::vector<PyThreadContextEntry> stack;
597   return stack;
598 }
599 
600 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
601   auto &stack = getStack();
602   if (stack.empty())
603     return nullptr;
604   return &stack.back();
605 }
606 
607 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
608                                 py::object insertionPoint,
609                                 py::object location) {
610   auto &stack = getStack();
611   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
612                      std::move(location));
613   // If the new stack has more than one entry and the context of the new top
614   // entry matches the previous, copy the insertionPoint and location from the
615   // previous entry if missing from the new top entry.
616   if (stack.size() > 1) {
617     auto &prev = *(stack.rbegin() + 1);
618     auto &current = stack.back();
619     if (current.context.is(prev.context)) {
620       // Default non-context objects from the previous entry.
621       if (!current.insertionPoint)
622         current.insertionPoint = prev.insertionPoint;
623       if (!current.location)
624         current.location = prev.location;
625     }
626   }
627 }
628 
629 PyMlirContext *PyThreadContextEntry::getContext() {
630   if (!context)
631     return nullptr;
632   return py::cast<PyMlirContext *>(context);
633 }
634 
635 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
636   if (!insertionPoint)
637     return nullptr;
638   return py::cast<PyInsertionPoint *>(insertionPoint);
639 }
640 
641 PyLocation *PyThreadContextEntry::getLocation() {
642   if (!location)
643     return nullptr;
644   return py::cast<PyLocation *>(location);
645 }
646 
647 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
648   auto *tos = getTopOfStack();
649   return tos ? tos->getContext() : nullptr;
650 }
651 
652 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
653   auto *tos = getTopOfStack();
654   return tos ? tos->getInsertionPoint() : nullptr;
655 }
656 
657 PyLocation *PyThreadContextEntry::getDefaultLocation() {
658   auto *tos = getTopOfStack();
659   return tos ? tos->getLocation() : nullptr;
660 }
661 
662 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
663   py::object contextObj = py::cast(context);
664   push(FrameKind::Context, /*context=*/contextObj,
665        /*insertionPoint=*/py::object(),
666        /*location=*/py::object());
667   return contextObj;
668 }
669 
670 void PyThreadContextEntry::popContext(PyMlirContext &context) {
671   auto &stack = getStack();
672   if (stack.empty())
673     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
674   auto &tos = stack.back();
675   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
676     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
677   stack.pop_back();
678 }
679 
680 py::object
681 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
682   py::object contextObj =
683       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
684   py::object insertionPointObj = py::cast(insertionPoint);
685   push(FrameKind::InsertionPoint,
686        /*context=*/contextObj,
687        /*insertionPoint=*/insertionPointObj,
688        /*location=*/py::object());
689   return insertionPointObj;
690 }
691 
692 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
693   auto &stack = getStack();
694   if (stack.empty())
695     throw SetPyError(PyExc_RuntimeError,
696                      "Unbalanced InsertionPoint enter/exit");
697   auto &tos = stack.back();
698   if (tos.frameKind != FrameKind::InsertionPoint &&
699       tos.getInsertionPoint() != &insertionPoint)
700     throw SetPyError(PyExc_RuntimeError,
701                      "Unbalanced InsertionPoint enter/exit");
702   stack.pop_back();
703 }
704 
705 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
706   py::object contextObj = location.getContext().getObject();
707   py::object locationObj = py::cast(location);
708   push(FrameKind::Location, /*context=*/contextObj,
709        /*insertionPoint=*/py::object(),
710        /*location=*/locationObj);
711   return locationObj;
712 }
713 
714 void PyThreadContextEntry::popLocation(PyLocation &location) {
715   auto &stack = getStack();
716   if (stack.empty())
717     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
718   auto &tos = stack.back();
719   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
720     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
721   stack.pop_back();
722 }
723 
724 //------------------------------------------------------------------------------
725 // PyDiagnostic*
726 //------------------------------------------------------------------------------
727 
728 void PyDiagnostic::invalidate() {
729   valid = false;
730   if (materializedNotes) {
731     for (auto &noteObject : *materializedNotes) {
732       PyDiagnostic *note = py::cast<PyDiagnostic *>(noteObject);
733       note->invalidate();
734     }
735   }
736 }
737 
738 PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context,
739                                          py::object callback)
740     : context(context), callback(std::move(callback)) {}
741 
742 PyDiagnosticHandler::~PyDiagnosticHandler() = default;
743 
744 void PyDiagnosticHandler::detach() {
745   if (!registeredID)
746     return;
747   MlirDiagnosticHandlerID localID = *registeredID;
748   mlirContextDetachDiagnosticHandler(context, localID);
749   assert(!registeredID && "should have unregistered");
750   // Not strictly necessary but keeps stale pointers from being around to cause
751   // issues.
752   context = {nullptr};
753 }
754 
755 void PyDiagnostic::checkValid() {
756   if (!valid) {
757     throw std::invalid_argument(
758         "Diagnostic is invalid (used outside of callback)");
759   }
760 }
761 
762 MlirDiagnosticSeverity PyDiagnostic::getSeverity() {
763   checkValid();
764   return mlirDiagnosticGetSeverity(diagnostic);
765 }
766 
767 PyLocation PyDiagnostic::getLocation() {
768   checkValid();
769   MlirLocation loc = mlirDiagnosticGetLocation(diagnostic);
770   MlirContext context = mlirLocationGetContext(loc);
771   return PyLocation(PyMlirContext::forContext(context), loc);
772 }
773 
774 py::str PyDiagnostic::getMessage() {
775   checkValid();
776   py::object fileObject = py::module::import("io").attr("StringIO")();
777   PyFileAccumulator accum(fileObject, /*binary=*/false);
778   mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData());
779   return fileObject.attr("getvalue")();
780 }
781 
782 py::tuple PyDiagnostic::getNotes() {
783   checkValid();
784   if (materializedNotes)
785     return *materializedNotes;
786   intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic);
787   materializedNotes = py::tuple(numNotes);
788   for (intptr_t i = 0; i < numNotes; ++i) {
789     MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i);
790     py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag));
791     PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr());
792   }
793   return *materializedNotes;
794 }
795 
796 //------------------------------------------------------------------------------
797 // PyDialect, PyDialectDescriptor, PyDialects
798 //------------------------------------------------------------------------------
799 
800 MlirDialect PyDialects::getDialectForKey(const std::string &key,
801                                          bool attrError) {
802   MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(),
803                                                     {key.data(), key.size()});
804   if (mlirDialectIsNull(dialect)) {
805     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
806                      Twine("Dialect '") + key + "' not found");
807   }
808   return dialect;
809 }
810 
811 //------------------------------------------------------------------------------
812 // PyLocation
813 //------------------------------------------------------------------------------
814 
815 py::object PyLocation::getCapsule() {
816   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
817 }
818 
819 PyLocation PyLocation::createFromCapsule(py::object capsule) {
820   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
821   if (mlirLocationIsNull(rawLoc))
822     throw py::error_already_set();
823   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
824                     rawLoc);
825 }
826 
827 py::object PyLocation::contextEnter() {
828   return PyThreadContextEntry::pushLocation(*this);
829 }
830 
831 void PyLocation::contextExit(const pybind11::object &excType,
832                              const pybind11::object &excVal,
833                              const pybind11::object &excTb) {
834   PyThreadContextEntry::popLocation(*this);
835 }
836 
837 PyLocation &DefaultingPyLocation::resolve() {
838   auto *location = PyThreadContextEntry::getDefaultLocation();
839   if (!location) {
840     throw SetPyError(
841         PyExc_RuntimeError,
842         "An MLIR function requires a Location but none was provided in the "
843         "call or from the surrounding environment. Either pass to the function "
844         "with a 'loc=' argument or establish a default using 'with loc:'");
845   }
846   return *location;
847 }
848 
849 //------------------------------------------------------------------------------
850 // PyModule
851 //------------------------------------------------------------------------------
852 
853 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
854     : BaseContextObject(std::move(contextRef)), module(module) {}
855 
856 PyModule::~PyModule() {
857   py::gil_scoped_acquire acquire;
858   auto &liveModules = getContext()->liveModules;
859   assert(liveModules.count(module.ptr) == 1 &&
860          "destroying module not in live map");
861   liveModules.erase(module.ptr);
862   mlirModuleDestroy(module);
863 }
864 
865 PyModuleRef PyModule::forModule(MlirModule module) {
866   MlirContext context = mlirModuleGetContext(module);
867   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
868 
869   py::gil_scoped_acquire acquire;
870   auto &liveModules = contextRef->liveModules;
871   auto it = liveModules.find(module.ptr);
872   if (it == liveModules.end()) {
873     // Create.
874     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
875     // Note that the default return value policy on cast is automatic_reference,
876     // which does not take ownership (delete will not be called).
877     // Just be explicit.
878     py::object pyRef =
879         py::cast(unownedModule, py::return_value_policy::take_ownership);
880     unownedModule->handle = pyRef;
881     liveModules[module.ptr] =
882         std::make_pair(unownedModule->handle, unownedModule);
883     return PyModuleRef(unownedModule, std::move(pyRef));
884   }
885   // Use existing.
886   PyModule *existing = it->second.second;
887   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
888   return PyModuleRef(existing, std::move(pyRef));
889 }
890 
891 py::object PyModule::createFromCapsule(py::object capsule) {
892   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
893   if (mlirModuleIsNull(rawModule))
894     throw py::error_already_set();
895   return forModule(rawModule).releaseObject();
896 }
897 
898 py::object PyModule::getCapsule() {
899   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
900 }
901 
902 //------------------------------------------------------------------------------
903 // PyOperation
904 //------------------------------------------------------------------------------
905 
906 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
907     : BaseContextObject(std::move(contextRef)), operation(operation) {}
908 
909 PyOperation::~PyOperation() {
910   // If the operation has already been invalidated there is nothing to do.
911   if (!valid)
912     return;
913   auto &liveOperations = getContext()->liveOperations;
914   assert(liveOperations.count(operation.ptr) == 1 &&
915          "destroying operation not in live map");
916   liveOperations.erase(operation.ptr);
917   if (!isAttached()) {
918     mlirOperationDestroy(operation);
919   }
920 }
921 
922 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
923                                            MlirOperation operation,
924                                            py::object parentKeepAlive) {
925   auto &liveOperations = contextRef->liveOperations;
926   // Create.
927   PyOperation *unownedOperation =
928       new PyOperation(std::move(contextRef), operation);
929   // Note that the default return value policy on cast is automatic_reference,
930   // which does not take ownership (delete will not be called).
931   // Just be explicit.
932   py::object pyRef =
933       py::cast(unownedOperation, py::return_value_policy::take_ownership);
934   unownedOperation->handle = pyRef;
935   if (parentKeepAlive) {
936     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
937   }
938   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
939   return PyOperationRef(unownedOperation, std::move(pyRef));
940 }
941 
942 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
943                                          MlirOperation operation,
944                                          py::object parentKeepAlive) {
945   auto &liveOperations = contextRef->liveOperations;
946   auto it = liveOperations.find(operation.ptr);
947   if (it == liveOperations.end()) {
948     // Create.
949     return createInstance(std::move(contextRef), operation,
950                           std::move(parentKeepAlive));
951   }
952   // Use existing.
953   PyOperation *existing = it->second.second;
954   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
955   return PyOperationRef(existing, std::move(pyRef));
956 }
957 
958 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
959                                            MlirOperation operation,
960                                            py::object parentKeepAlive) {
961   auto &liveOperations = contextRef->liveOperations;
962   assert(liveOperations.count(operation.ptr) == 0 &&
963          "cannot create detached operation that already exists");
964   (void)liveOperations;
965 
966   PyOperationRef created = createInstance(std::move(contextRef), operation,
967                                           std::move(parentKeepAlive));
968   created->attached = false;
969   return created;
970 }
971 
972 void PyOperation::checkValid() const {
973   if (!valid) {
974     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
975   }
976 }
977 
978 void PyOperationBase::print(py::object fileObject, bool binary,
979                             llvm::Optional<int64_t> largeElementsLimit,
980                             bool enableDebugInfo, bool prettyDebugInfo,
981                             bool printGenericOpForm, bool useLocalScope,
982                             bool assumeVerified) {
983   PyOperation &operation = getOperation();
984   operation.checkValid();
985   if (fileObject.is_none())
986     fileObject = py::module::import("sys").attr("stdout");
987 
988   if (!assumeVerified && !printGenericOpForm &&
989       !mlirOperationVerify(operation)) {
990     std::string message("// Verification failed, printing generic form\n");
991     if (binary) {
992       fileObject.attr("write")(py::bytes(message));
993     } else {
994       fileObject.attr("write")(py::str(message));
995     }
996     printGenericOpForm = true;
997   }
998 
999   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
1000   if (largeElementsLimit)
1001     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
1002   if (enableDebugInfo)
1003     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
1004   if (printGenericOpForm)
1005     mlirOpPrintingFlagsPrintGenericOpForm(flags);
1006 
1007   PyFileAccumulator accum(fileObject, binary);
1008   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
1009                               accum.getUserData());
1010   mlirOpPrintingFlagsDestroy(flags);
1011 }
1012 
1013 py::object PyOperationBase::getAsm(bool binary,
1014                                    llvm::Optional<int64_t> largeElementsLimit,
1015                                    bool enableDebugInfo, bool prettyDebugInfo,
1016                                    bool printGenericOpForm, bool useLocalScope,
1017                                    bool assumeVerified) {
1018   py::object fileObject;
1019   if (binary) {
1020     fileObject = py::module::import("io").attr("BytesIO")();
1021   } else {
1022     fileObject = py::module::import("io").attr("StringIO")();
1023   }
1024   print(fileObject, /*binary=*/binary,
1025         /*largeElementsLimit=*/largeElementsLimit,
1026         /*enableDebugInfo=*/enableDebugInfo,
1027         /*prettyDebugInfo=*/prettyDebugInfo,
1028         /*printGenericOpForm=*/printGenericOpForm,
1029         /*useLocalScope=*/useLocalScope,
1030         /*assumeVerified=*/assumeVerified);
1031 
1032   return fileObject.attr("getvalue")();
1033 }
1034 
1035 void PyOperationBase::moveAfter(PyOperationBase &other) {
1036   PyOperation &operation = getOperation();
1037   PyOperation &otherOp = other.getOperation();
1038   operation.checkValid();
1039   otherOp.checkValid();
1040   mlirOperationMoveAfter(operation, otherOp);
1041   operation.parentKeepAlive = otherOp.parentKeepAlive;
1042 }
1043 
1044 void PyOperationBase::moveBefore(PyOperationBase &other) {
1045   PyOperation &operation = getOperation();
1046   PyOperation &otherOp = other.getOperation();
1047   operation.checkValid();
1048   otherOp.checkValid();
1049   mlirOperationMoveBefore(operation, otherOp);
1050   operation.parentKeepAlive = otherOp.parentKeepAlive;
1051 }
1052 
1053 llvm::Optional<PyOperationRef> PyOperation::getParentOperation() {
1054   checkValid();
1055   if (!isAttached())
1056     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
1057   MlirOperation operation = mlirOperationGetParentOperation(get());
1058   if (mlirOperationIsNull(operation))
1059     return {};
1060   return PyOperation::forOperation(getContext(), operation);
1061 }
1062 
1063 PyBlock PyOperation::getBlock() {
1064   checkValid();
1065   llvm::Optional<PyOperationRef> parentOperation = getParentOperation();
1066   MlirBlock block = mlirOperationGetBlock(get());
1067   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
1068   assert(parentOperation && "Operation has no parent");
1069   return PyBlock{std::move(*parentOperation), block};
1070 }
1071 
1072 py::object PyOperation::getCapsule() {
1073   checkValid();
1074   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
1075 }
1076 
1077 py::object PyOperation::createFromCapsule(py::object capsule) {
1078   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
1079   if (mlirOperationIsNull(rawOperation))
1080     throw py::error_already_set();
1081   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
1082   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
1083       .releaseObject();
1084 }
1085 
1086 static void maybeInsertOperation(PyOperationRef &op,
1087                                  const py::object &maybeIp) {
1088   // InsertPoint active?
1089   if (!maybeIp.is(py::cast(false))) {
1090     PyInsertionPoint *ip;
1091     if (maybeIp.is_none()) {
1092       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1093     } else {
1094       ip = py::cast<PyInsertionPoint *>(maybeIp);
1095     }
1096     if (ip)
1097       ip->insert(*op.get());
1098   }
1099 }
1100 
// Low-level operation constructor (backs Operation.create).
//
// Builds an operation named `name` at `location` from optional result types,
// operands, a str -> Attribute dict, successor blocks, and `regions` newly
// created regions. The new operation starts detached, is then inserted
// according to `maybeIp` (see maybeInsertOperation: `False` = no insertion,
// `None` = default insertion point), and is returned via createOpView().
//
// Raises ValueError for a negative region count or None operands/results/
// successors, and py::cast_error for non-string attribute keys or
// non-Attribute values. All of this validation happens before any C-side
// MlirOperationState is created, so nothing leaks when it fails.
py::object PyOperation::create(
    const std::string &name, llvm::Optional<std::vector<PyType *>> results,
    llvm::Optional<std::vector<PyValue *>> operands,
    llvm::Optional<py::dict> attributes,
    llvm::Optional<std::vector<PyBlock *>> successors, int regions,
    DefaultingPyLocation location, const py::object &maybeIp) {
  llvm::SmallVector<MlirValue, 4> mlirOperands;
  llvm::SmallVector<MlirType, 4> mlirResults;
  llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
  // Attribute names are kept as owned std::strings here because the
  // MlirIdentifiers built below reference these bytes.
  llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;

  // General parameter validation.
  if (regions < 0)
    throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");

  // Unpack/validate operands.
  if (operands) {
    mlirOperands.reserve(operands->size());
    for (PyValue *operand : *operands) {
      if (!operand)
        throw SetPyError(PyExc_ValueError, "operand value cannot be None");
      mlirOperands.push_back(operand->get());
    }
  }

  // Unpack/validate results.
  if (results) {
    mlirResults.reserve(results->size());
    for (PyType *result : *results) {
      // TODO: Verify result type originate from the same context.
      if (!result)
        throw SetPyError(PyExc_ValueError, "result type cannot be None");
      mlirResults.push_back(*result);
    }
  }
  // Unpack/validate attributes.
  if (attributes) {
    mlirAttributes.reserve(attributes->size());
    for (auto &it : *attributes) {
      std::string key;
      try {
        key = it.first.cast<std::string>();
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute key (not a string) when "
                          "attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
      try {
        auto &attribute = it.second.cast<PyAttribute &>();
        // TODO: Verify attribute originates from the same context.
        mlirAttributes.emplace_back(std::move(key), attribute);
      } catch (py::reference_cast_error &) {
        // This exception seems thrown when the value is "None".
        std::string msg =
            "Found an invalid (`None`?) attribute value for the key \"" + key +
            "\" when attempting to create the operation \"" + name + "\"";
        throw py::cast_error(msg);
      } catch (py::cast_error &err) {
        std::string msg = "Invalid attribute value for the key \"" + key +
                          "\" when attempting to create the operation \"" +
                          name + "\" (" + err.what() + ")";
        throw py::cast_error(msg);
      }
    }
  }
  // Unpack/validate successors.
  if (successors) {
    mlirSuccessors.reserve(successors->size());
    for (auto *successor : *successors) {
      // TODO: Verify successor originate from the same context.
      if (!successor)
        throw SetPyError(PyExc_ValueError, "successor block cannot be None");
      mlirSuccessors.push_back(successor->get());
    }
  }

  // Apply unpacked/validated to the operation state. Beyond this
  // point, exceptions cannot be thrown or else the state will leak.
  MlirOperationState state =
      mlirOperationStateGet(toMlirStringRef(name), location);
  if (!mlirOperands.empty())
    mlirOperationStateAddOperands(&state, mlirOperands.size(),
                                  mlirOperands.data());
  if (!mlirResults.empty())
    mlirOperationStateAddResults(&state, mlirResults.size(),
                                 mlirResults.data());
  if (!mlirAttributes.empty()) {
    // Note that the attribute names directly reference bytes in
    // mlirAttributes, so that vector must not be changed from here
    // on.
    llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
    mlirNamedAttributes.reserve(mlirAttributes.size());
    for (auto &it : mlirAttributes)
      mlirNamedAttributes.push_back(mlirNamedAttributeGet(
          mlirIdentifierGet(mlirAttributeGetContext(it.second),
                            toMlirStringRef(it.first)),
          it.second));
    mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
                                    mlirNamedAttributes.data());
  }
  if (!mlirSuccessors.empty())
    mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
                                    mlirSuccessors.data());
  if (regions) {
    // Each region is created here and handed to the state via the
    // "AddOwnedRegions" C API.
    llvm::SmallVector<MlirRegion, 4> mlirRegions;
    mlirRegions.resize(regions);
    for (int i = 0; i < regions; ++i)
      mlirRegions[i] = mlirRegionCreate();
    mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
                                      mlirRegions.data());
  }

  // Construct the operation.
  MlirOperation operation = mlirOperationCreate(&state);
  // The new operation is not in any block yet, so wrap it as detached before
  // (optionally) inserting it.
  PyOperationRef created =
      PyOperation::createDetached(location->getContext(), operation);
  maybeInsertOperation(created, maybeIp);

  return created->createOpView();
}
1222 
1223 py::object PyOperation::clone(const py::object &maybeIp) {
1224   MlirOperation clonedOperation = mlirOperationClone(operation);
1225   PyOperationRef cloned =
1226       PyOperation::createDetached(getContext(), clonedOperation);
1227   maybeInsertOperation(cloned, maybeIp);
1228 
1229   return cloned->createOpView();
1230 }
1231 
1232 py::object PyOperation::createOpView() {
1233   checkValid();
1234   MlirIdentifier ident = mlirOperationGetName(get());
1235   MlirStringRef identStr = mlirIdentifierStr(ident);
1236   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1237       StringRef(identStr.data, identStr.length));
1238   if (opViewClass)
1239     return (*opViewClass)(getRef().getObject());
1240   return py::cast(PyOpView(getRef().getObject()));
1241 }
1242 
1243 void PyOperation::erase() {
1244   checkValid();
1245   // TODO: Fix memory hazards when erasing a tree of operations for which a deep
1246   // Python reference to a child operation is live. All children should also
1247   // have their `valid` bit set to false.
1248   auto &liveOperations = getContext()->liveOperations;
1249   if (liveOperations.count(operation.ptr))
1250     liveOperations.erase(operation.ptr);
1251   mlirOperationDestroy(operation);
1252   valid = false;
1253 }
1254 
1255 //------------------------------------------------------------------------------
1256 // PyOpView
1257 //------------------------------------------------------------------------------
1258 
1259 py::object PyOpView::buildGeneric(
1260     const py::object &cls, py::list resultTypeList, py::list operandList,
1261     llvm::Optional<py::dict> attributes,
1262     llvm::Optional<std::vector<PyBlock *>> successors,
1263     llvm::Optional<int> regions, DefaultingPyLocation location,
1264     const py::object &maybeIp) {
1265   PyMlirContextRef context = location->getContext();
1266   // Class level operation construction metadata.
1267   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1268   // Operand and result segment specs are either none, which does no
1269   // variadic unpacking, or a list of ints with segment sizes, where each
1270   // element is either a positive number (typically 1 for a scalar) or -1 to
1271   // indicate that it is derived from the length of the same-indexed operand
1272   // or result (implying that it is a list at that position).
1273   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1274   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1275 
1276   std::vector<uint32_t> operandSegmentLengths;
1277   std::vector<uint32_t> resultSegmentLengths;
1278 
1279   // Validate/determine region count.
1280   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1281   int opMinRegionCount = std::get<0>(opRegionSpec);
1282   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1283   if (!regions) {
1284     regions = opMinRegionCount;
1285   }
1286   if (*regions < opMinRegionCount) {
1287     throw py::value_error(
1288         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1289          llvm::Twine(opMinRegionCount) +
1290          " regions but was built with regions=" + llvm::Twine(*regions))
1291             .str());
1292   }
1293   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1294     throw py::value_error(
1295         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1296          llvm::Twine(opMinRegionCount) +
1297          " regions but was built with regions=" + llvm::Twine(*regions))
1298             .str());
1299   }
1300 
1301   // Unpack results.
1302   std::vector<PyType *> resultTypes;
1303   resultTypes.reserve(resultTypeList.size());
1304   if (resultSegmentSpecObj.is_none()) {
1305     // Non-variadic result unpacking.
1306     for (const auto &it : llvm::enumerate(resultTypeList)) {
1307       try {
1308         resultTypes.push_back(py::cast<PyType *>(it.value()));
1309         if (!resultTypes.back())
1310           throw py::cast_error();
1311       } catch (py::cast_error &err) {
1312         throw py::value_error((llvm::Twine("Result ") +
1313                                llvm::Twine(it.index()) + " of operation \"" +
1314                                name + "\" must be a Type (" + err.what() + ")")
1315                                   .str());
1316       }
1317     }
1318   } else {
1319     // Sized result unpacking.
1320     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1321     if (resultSegmentSpec.size() != resultTypeList.size()) {
1322       throw py::value_error((llvm::Twine("Operation \"") + name +
1323                              "\" requires " +
1324                              llvm::Twine(resultSegmentSpec.size()) +
1325                              " result segments but was provided " +
1326                              llvm::Twine(resultTypeList.size()))
1327                                 .str());
1328     }
1329     resultSegmentLengths.reserve(resultTypeList.size());
1330     for (const auto &it :
1331          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1332       int segmentSpec = std::get<1>(it.value());
1333       if (segmentSpec == 1 || segmentSpec == 0) {
1334         // Unpack unary element.
1335         try {
1336           auto *resultType = py::cast<PyType *>(std::get<0>(it.value()));
1337           if (resultType) {
1338             resultTypes.push_back(resultType);
1339             resultSegmentLengths.push_back(1);
1340           } else if (segmentSpec == 0) {
1341             // Allowed to be optional.
1342             resultSegmentLengths.push_back(0);
1343           } else {
1344             throw py::cast_error("was None and result is not optional");
1345           }
1346         } catch (py::cast_error &err) {
1347           throw py::value_error((llvm::Twine("Result ") +
1348                                  llvm::Twine(it.index()) + " of operation \"" +
1349                                  name + "\" must be a Type (" + err.what() +
1350                                  ")")
1351                                     .str());
1352         }
1353       } else if (segmentSpec == -1) {
1354         // Unpack sequence by appending.
1355         try {
1356           if (std::get<0>(it.value()).is_none()) {
1357             // Treat it as an empty list.
1358             resultSegmentLengths.push_back(0);
1359           } else {
1360             // Unpack the list.
1361             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1362             for (py::object segmentItem : segment) {
1363               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1364               if (!resultTypes.back()) {
1365                 throw py::cast_error("contained a None item");
1366               }
1367             }
1368             resultSegmentLengths.push_back(segment.size());
1369           }
1370         } catch (std::exception &err) {
1371           // NOTE: Sloppy to be using a catch-all here, but there are at least
1372           // three different unrelated exceptions that can be thrown in the
1373           // above "casts". Just keep the scope above small and catch them all.
1374           throw py::value_error((llvm::Twine("Result ") +
1375                                  llvm::Twine(it.index()) + " of operation \"" +
1376                                  name + "\" must be a Sequence of Types (" +
1377                                  err.what() + ")")
1378                                     .str());
1379         }
1380       } else {
1381         throw py::value_error("Unexpected segment spec");
1382       }
1383     }
1384   }
1385 
1386   // Unpack operands.
1387   std::vector<PyValue *> operands;
1388   operands.reserve(operands.size());
1389   if (operandSegmentSpecObj.is_none()) {
1390     // Non-sized operand unpacking.
1391     for (const auto &it : llvm::enumerate(operandList)) {
1392       try {
1393         operands.push_back(py::cast<PyValue *>(it.value()));
1394         if (!operands.back())
1395           throw py::cast_error();
1396       } catch (py::cast_error &err) {
1397         throw py::value_error((llvm::Twine("Operand ") +
1398                                llvm::Twine(it.index()) + " of operation \"" +
1399                                name + "\" must be a Value (" + err.what() + ")")
1400                                   .str());
1401       }
1402     }
1403   } else {
1404     // Sized operand unpacking.
1405     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1406     if (operandSegmentSpec.size() != operandList.size()) {
1407       throw py::value_error((llvm::Twine("Operation \"") + name +
1408                              "\" requires " +
1409                              llvm::Twine(operandSegmentSpec.size()) +
1410                              "operand segments but was provided " +
1411                              llvm::Twine(operandList.size()))
1412                                 .str());
1413     }
1414     operandSegmentLengths.reserve(operandList.size());
1415     for (const auto &it :
1416          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1417       int segmentSpec = std::get<1>(it.value());
1418       if (segmentSpec == 1 || segmentSpec == 0) {
1419         // Unpack unary element.
1420         try {
1421           auto *operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1422           if (operandValue) {
1423             operands.push_back(operandValue);
1424             operandSegmentLengths.push_back(1);
1425           } else if (segmentSpec == 0) {
1426             // Allowed to be optional.
1427             operandSegmentLengths.push_back(0);
1428           } else {
1429             throw py::cast_error("was None and operand is not optional");
1430           }
1431         } catch (py::cast_error &err) {
1432           throw py::value_error((llvm::Twine("Operand ") +
1433                                  llvm::Twine(it.index()) + " of operation \"" +
1434                                  name + "\" must be a Value (" + err.what() +
1435                                  ")")
1436                                     .str());
1437         }
1438       } else if (segmentSpec == -1) {
1439         // Unpack sequence by appending.
1440         try {
1441           if (std::get<0>(it.value()).is_none()) {
1442             // Treat it as an empty list.
1443             operandSegmentLengths.push_back(0);
1444           } else {
1445             // Unpack the list.
1446             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1447             for (py::object segmentItem : segment) {
1448               operands.push_back(py::cast<PyValue *>(segmentItem));
1449               if (!operands.back()) {
1450                 throw py::cast_error("contained a None item");
1451               }
1452             }
1453             operandSegmentLengths.push_back(segment.size());
1454           }
1455         } catch (std::exception &err) {
1456           // NOTE: Sloppy to be using a catch-all here, but there are at least
1457           // three different unrelated exceptions that can be thrown in the
1458           // above "casts". Just keep the scope above small and catch them all.
1459           throw py::value_error((llvm::Twine("Operand ") +
1460                                  llvm::Twine(it.index()) + " of operation \"" +
1461                                  name + "\" must be a Sequence of Values (" +
1462                                  err.what() + ")")
1463                                     .str());
1464         }
1465       } else {
1466         throw py::value_error("Unexpected segment spec");
1467       }
1468     }
1469   }
1470 
1471   // Merge operand/result segment lengths into attributes if needed.
1472   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1473     // Dup.
1474     if (attributes) {
1475       attributes = py::dict(*attributes);
1476     } else {
1477       attributes = py::dict();
1478     }
1479     if (attributes->contains("result_segment_sizes") ||
1480         attributes->contains("operand_segment_sizes")) {
1481       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1482                             "'operand_segment_sizes' attribute is unsupported. "
1483                             "Use Operation.create for such low-level access.");
1484     }
1485 
1486     // Add result_segment_sizes attribute.
1487     if (!resultSegmentLengths.empty()) {
1488       int64_t size = resultSegmentLengths.size();
1489       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1490           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1491           resultSegmentLengths.size(), resultSegmentLengths.data());
1492       (*attributes)["result_segment_sizes"] =
1493           PyAttribute(context, segmentLengthAttr);
1494     }
1495 
1496     // Add operand_segment_sizes attribute.
1497     if (!operandSegmentLengths.empty()) {
1498       int64_t size = operandSegmentLengths.size();
1499       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1500           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1501           operandSegmentLengths.size(), operandSegmentLengths.data());
1502       (*attributes)["operand_segment_sizes"] =
1503           PyAttribute(context, segmentLengthAttr);
1504     }
1505   }
1506 
1507   // Delegate to create.
1508   return PyOperation::create(name,
1509                              /*results=*/std::move(resultTypes),
1510                              /*operands=*/std::move(operands),
1511                              /*attributes=*/std::move(attributes),
1512                              /*successors=*/std::move(successors),
1513                              /*regions=*/*regions, location, maybeIp);
1514 }
1515 
// Builds an OpView over an existing operation object. Any PyOperationBase
// subclass is accepted (the object is cast to PyOperationBase, then resolved
// to its underlying PyOperation).
// NOTE(review): `operationObject` is initialized from `operation`, so this
// relies on `operation` being declared before `operationObject` in the class;
// confirm against IRModule.h.
PyOpView::PyOpView(const py::object &operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1521 
1522 py::object PyOpView::createRawSubclass(const py::object &userClass) {
1523   // This is... a little gross. The typical pattern is to have a pure python
1524   // class that extends OpView like:
1525   //   class AddFOp(_cext.ir.OpView):
1526   //     def __init__(self, loc, lhs, rhs):
1527   //       operation = loc.context.create_operation(
1528   //           "addf", lhs, rhs, results=[lhs.type])
1529   //       super().__init__(operation)
1530   //
1531   // I.e. The goal of the user facing type is to provide a nice constructor
1532   // that has complete freedom for the op under construction. This is at odds
1533   // with our other desire to sometimes create this object by just passing an
1534   // operation (to initialize the base class). We could do *arg and **kwargs
1535   // munging to try to make it work, but instead, we synthesize a new class
1536   // on the fly which extends this user class (AddFOp in this example) and
1537   // *give it* the base class's __init__ method, thus bypassing the
1538   // intermediate subclass's __init__ method entirely. While slightly,
1539   // underhanded, this is safe/legal because the type hierarchy has not changed
1540   // (we just added a new leaf) and we aren't mucking around with __new__.
1541   // Typically, this new class will be stored on the original as "_Raw" and will
1542   // be used for casts and other things that need a variant of the class that
1543   // is initialized purely from an operation.
1544   py::object parentMetaclass =
1545       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1546   py::dict attributes;
1547   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1548   // now.
1549   //   auto opViewType = py::type::of<PyOpView>();
1550   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1551   attributes["__init__"] = opViewType.attr("__init__");
1552   py::str origName = userClass.attr("__name__");
1553   py::str newName = py::str("_") + origName;
1554   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1555 }
1556 
1557 //------------------------------------------------------------------------------
1558 // PyInsertionPoint.
1559 //------------------------------------------------------------------------------
1560 
/// Creates an insertion point in `block` with no reference operation; insert()
/// then appends at the end of the block.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}

/// Creates an insertion point immediately before the operation wrapped by
/// `beforeOperationBase`, inside that operation's block.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      // Reads the `refOperation` member initialized just above; this relies
      // on `refOperation` being declared before `block` in the class.
      block((*refOperation)->getBlock()) {}
1566 
1567 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1568   PyOperation &operation = operationBase.getOperation();
1569   if (operation.isAttached())
1570     throw SetPyError(PyExc_ValueError,
1571                      "Attempt to insert operation that is already attached");
1572   block.getParentOperation()->checkValid();
1573   MlirOperation beforeOp = {nullptr};
1574   if (refOperation) {
1575     // Insert before operation.
1576     (*refOperation)->checkValid();
1577     beforeOp = (*refOperation)->get();
1578   } else {
1579     // Insert at end (before null) is only valid if the block does not
1580     // already end in a known terminator (violating this will cause assertion
1581     // failures later).
1582     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1583       throw py::index_error("Cannot insert operation at the end of a block "
1584                             "that already has a terminator. Did you mean to "
1585                             "use 'InsertionPoint.at_block_terminator(block)' "
1586                             "versus 'InsertionPoint(block)'?");
1587     }
1588   }
1589   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1590   operation.setAttached();
1591 }
1592 
1593 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1594   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1595   if (mlirOperationIsNull(firstOp)) {
1596     // Just insert at end.
1597     return PyInsertionPoint(block);
1598   }
1599 
1600   // Insert before first op.
1601   PyOperationRef firstOpRef = PyOperation::forOperation(
1602       block.getParentOperation()->getContext(), firstOp);
1603   return PyInsertionPoint{block, std::move(firstOpRef)};
1604 }
1605 
1606 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1607   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1608   if (mlirOperationIsNull(terminator))
1609     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1610   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1611       block.getParentOperation()->getContext(), terminator);
1612   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1613 }
1614 
1615 py::object PyInsertionPoint::contextEnter() {
1616   return PyThreadContextEntry::pushInsertionPoint(*this);
1617 }
1618 
1619 void PyInsertionPoint::contextExit(const pybind11::object &excType,
1620                                    const pybind11::object &excVal,
1621                                    const pybind11::object &excTb) {
1622   PyThreadContextEntry::popInsertionPoint(*this);
1623 }
1624 
1625 //------------------------------------------------------------------------------
1626 // PyAttribute.
1627 //------------------------------------------------------------------------------
1628 
1629 bool PyAttribute::operator==(const PyAttribute &other) {
1630   return mlirAttributeEqual(attr, other.attr);
1631 }
1632 
1633 py::object PyAttribute::getCapsule() {
1634   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1635 }
1636 
1637 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1638   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1639   if (mlirAttributeIsNull(rawAttr))
1640     throw py::error_already_set();
1641   return PyAttribute(
1642       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1643 }
1644 
1645 //------------------------------------------------------------------------------
1646 // PyNamedAttribute.
1647 //------------------------------------------------------------------------------
1648 
1649 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1650     : ownedName(new std::string(std::move(ownedName))) {
1651   namedAttr = mlirNamedAttributeGet(
1652       mlirIdentifierGet(mlirAttributeGetContext(attr),
1653                         toMlirStringRef(*this->ownedName)),
1654       attr);
1655 }
1656 
1657 //------------------------------------------------------------------------------
1658 // PyType.
1659 //------------------------------------------------------------------------------
1660 
1661 bool PyType::operator==(const PyType &other) {
1662   return mlirTypeEqual(type, other.type);
1663 }
1664 
1665 py::object PyType::getCapsule() {
1666   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1667 }
1668 
1669 PyType PyType::createFromCapsule(py::object capsule) {
1670   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1671   if (mlirTypeIsNull(rawType))
1672     throw py::error_already_set();
1673   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1674                 rawType);
1675 }
1676 
1677 //------------------------------------------------------------------------------
1678 // PyValue and subclases.
1679 //------------------------------------------------------------------------------
1680 
1681 pybind11::object PyValue::getCapsule() {
1682   return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
1683 }
1684 
1685 PyValue PyValue::createFromCapsule(pybind11::object capsule) {
1686   MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
1687   if (mlirValueIsNull(value))
1688     throw py::error_already_set();
1689   MlirOperation owner;
1690   if (mlirValueIsAOpResult(value))
1691     owner = mlirOpResultGetOwner(value);
1692   if (mlirValueIsABlockArgument(value))
1693     owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value));
1694   if (mlirOperationIsNull(owner))
1695     throw py::error_already_set();
1696   MlirContext ctx = mlirOperationGetContext(owner);
1697   PyOperationRef ownerRef =
1698       PyOperation::forOperation(PyMlirContext::forContext(ctx), owner);
1699   return PyValue(ownerRef, value);
1700 }
1701 
1702 //------------------------------------------------------------------------------
1703 // PySymbolTable.
1704 //------------------------------------------------------------------------------
1705 
1706 PySymbolTable::PySymbolTable(PyOperationBase &operation)
1707     : operation(operation.getOperation().getRef()) {
1708   symbolTable = mlirSymbolTableCreate(operation.getOperation().get());
1709   if (mlirSymbolTableIsNull(symbolTable)) {
1710     throw py::cast_error("Operation is not a Symbol Table.");
1711   }
1712 }
1713 
1714 py::object PySymbolTable::dunderGetItem(const std::string &name) {
1715   operation->checkValid();
1716   MlirOperation symbol = mlirSymbolTableLookup(
1717       symbolTable, mlirStringRefCreate(name.data(), name.length()));
1718   if (mlirOperationIsNull(symbol))
1719     throw py::key_error("Symbol '" + name + "' not in the symbol table.");
1720 
1721   return PyOperation::forOperation(operation->getContext(), symbol,
1722                                    operation.getObject())
1723       ->createOpView();
1724 }
1725 
1726 void PySymbolTable::erase(PyOperationBase &symbol) {
1727   operation->checkValid();
1728   symbol.getOperation().checkValid();
1729   mlirSymbolTableErase(symbolTable, symbol.getOperation().get());
1730   // The operation is also erased, so we must invalidate it. There may be Python
1731   // references to this operation so we don't want to delete it from the list of
1732   // live operations here.
1733   symbol.getOperation().valid = false;
1734 }
1735 
1736 void PySymbolTable::dunderDel(const std::string &name) {
1737   py::object operation = dunderGetItem(name);
1738   erase(py::cast<PyOperationBase &>(operation));
1739 }
1740 
1741 PyAttribute PySymbolTable::insert(PyOperationBase &symbol) {
1742   operation->checkValid();
1743   symbol.getOperation().checkValid();
1744   MlirAttribute symbolAttr = mlirOperationGetAttributeByName(
1745       symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName());
1746   if (mlirAttributeIsNull(symbolAttr))
1747     throw py::value_error("Expected operation to have a symbol name.");
1748   return PyAttribute(
1749       symbol.getOperation().getContext(),
1750       mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()));
1751 }
1752 
1753 PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) {
1754   // Op must already be a symbol.
1755   PyOperation &operation = symbol.getOperation();
1756   operation.checkValid();
1757   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1758   MlirAttribute existingNameAttr =
1759       mlirOperationGetAttributeByName(operation.get(), attrName);
1760   if (mlirAttributeIsNull(existingNameAttr))
1761     throw py::value_error("Expected operation to have a symbol name.");
1762   return PyAttribute(symbol.getOperation().getContext(), existingNameAttr);
1763 }
1764 
1765 void PySymbolTable::setSymbolName(PyOperationBase &symbol,
1766                                   const std::string &name) {
1767   // Op must already be a symbol.
1768   PyOperation &operation = symbol.getOperation();
1769   operation.checkValid();
1770   MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName();
1771   MlirAttribute existingNameAttr =
1772       mlirOperationGetAttributeByName(operation.get(), attrName);
1773   if (mlirAttributeIsNull(existingNameAttr))
1774     throw py::value_error("Expected operation to have a symbol name.");
1775   MlirAttribute newNameAttr =
1776       mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name));
1777   mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr);
1778 }
1779 
1780 PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) {
1781   PyOperation &operation = symbol.getOperation();
1782   operation.checkValid();
1783   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1784   MlirAttribute existingVisAttr =
1785       mlirOperationGetAttributeByName(operation.get(), attrName);
1786   if (mlirAttributeIsNull(existingVisAttr))
1787     throw py::value_error("Expected operation to have a symbol visibility.");
1788   return PyAttribute(symbol.getOperation().getContext(), existingVisAttr);
1789 }
1790 
1791 void PySymbolTable::setVisibility(PyOperationBase &symbol,
1792                                   const std::string &visibility) {
1793   if (visibility != "public" && visibility != "private" &&
1794       visibility != "nested")
1795     throw py::value_error(
1796         "Expected visibility to be 'public', 'private' or 'nested'");
1797   PyOperation &operation = symbol.getOperation();
1798   operation.checkValid();
1799   MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName();
1800   MlirAttribute existingVisAttr =
1801       mlirOperationGetAttributeByName(operation.get(), attrName);
1802   if (mlirAttributeIsNull(existingVisAttr))
1803     throw py::value_error("Expected operation to have a symbol visibility.");
1804   MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(),
1805                                                toMlirStringRef(visibility));
1806   mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr);
1807 }
1808 
1809 void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol,
1810                                          const std::string &newSymbol,
1811                                          PyOperationBase &from) {
1812   PyOperation &fromOperation = from.getOperation();
1813   fromOperation.checkValid();
1814   if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses(
1815           toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol),
1816           from.getOperation())))
1817 
1818     throw py::value_error("Symbol rename failed");
1819 }
1820 
1821 void PySymbolTable::walkSymbolTables(PyOperationBase &from,
1822                                      bool allSymUsesVisible,
1823                                      py::object callback) {
1824   PyOperation &fromOperation = from.getOperation();
1825   fromOperation.checkValid();
1826   struct UserData {
1827     PyMlirContextRef context;
1828     py::object callback;
1829     bool gotException;
1830     std::string exceptionWhat;
1831     py::object exceptionType;
1832   };
1833   UserData userData{
1834       fromOperation.getContext(), std::move(callback), false, {}, {}};
1835   mlirSymbolTableWalkSymbolTables(
1836       fromOperation.get(), allSymUsesVisible,
1837       [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) {
1838         UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid);
1839         auto pyFoundOp =
1840             PyOperation::forOperation(calleeUserData->context, foundOp);
1841         if (calleeUserData->gotException)
1842           return;
1843         try {
1844           calleeUserData->callback(pyFoundOp.getObject(), isVisible);
1845         } catch (py::error_already_set &e) {
1846           calleeUserData->gotException = true;
1847           calleeUserData->exceptionWhat = e.what();
1848           calleeUserData->exceptionType = e.type();
1849         }
1850       },
1851       static_cast<void *>(&userData));
1852   if (userData.gotException) {
1853     std::string message("Exception raised in callback: ");
1854     message.append(userData.exceptionWhat);
1855     throw std::runtime_error(message);
1856   }
1857 }
1858 
1859 namespace {
1860 /// CRTP base class for Python MLIR values that subclass Value and should be
1861 /// castable from it. The value hierarchy is one level deep and is not supposed
1862 /// to accommodate other levels unless core MLIR changes.
1863 template <typename DerivedTy>
1864 class PyConcreteValue : public PyValue {
1865 public:
1866   // Derived classes must define statics for:
1867   //   IsAFunctionTy isaFunction
1868   //   const char *pyClassName
1869   // and redefine bindDerived.
1870   using ClassTy = py::class_<DerivedTy, PyValue>;
1871   using IsAFunctionTy = bool (*)(MlirValue);
1872 
1873   PyConcreteValue() = default;
1874   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1875       : PyValue(operationRef, value) {}
1876   PyConcreteValue(PyValue &orig)
1877       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1878 
1879   /// Attempts to cast the original value to the derived type and throws on
1880   /// type mismatches.
1881   static MlirValue castFrom(PyValue &orig) {
1882     if (!DerivedTy::isaFunction(orig.get())) {
1883       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1884       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1885                                              DerivedTy::pyClassName +
1886                                              " (from " + origRepr + ")");
1887     }
1888     return orig.get();
1889   }
1890 
1891   /// Binds the Python module objects to functions of this class.
1892   static void bind(py::module &m) {
1893     auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local());
1894     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>(), py::arg("value"));
1895     cls.def_static(
1896         "isinstance",
1897         [](PyValue &otherValue) -> bool {
1898           return DerivedTy::isaFunction(otherValue);
1899         },
1900         py::arg("other_value"));
1901     DerivedTy::bindDerived(cls);
1902   }
1903 
1904   /// Implemented by derived classes to add methods to the Python subclass.
1905   static void bindDerived(ClassTy &m) {}
1906 };
1907 
1908 /// Python wrapper for MlirBlockArgument.
1909 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1910 public:
1911   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1912   static constexpr const char *pyClassName = "BlockArgument";
1913   using PyConcreteValue::PyConcreteValue;
1914 
1915   static void bindDerived(ClassTy &c) {
1916     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1917       return PyBlock(self.getParentOperation(),
1918                      mlirBlockArgumentGetOwner(self.get()));
1919     });
1920     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1921       return mlirBlockArgumentGetArgNumber(self.get());
1922     });
1923     c.def(
1924         "set_type",
1925         [](PyBlockArgument &self, PyType type) {
1926           return mlirBlockArgumentSetType(self.get(), type);
1927         },
1928         py::arg("type"));
1929   }
1930 };
1931 
1932 /// Python wrapper for MlirOpResult.
1933 class PyOpResult : public PyConcreteValue<PyOpResult> {
1934 public:
1935   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1936   static constexpr const char *pyClassName = "OpResult";
1937   using PyConcreteValue::PyConcreteValue;
1938 
1939   static void bindDerived(ClassTy &c) {
1940     c.def_property_readonly("owner", [](PyOpResult &self) {
1941       assert(
1942           mlirOperationEqual(self.getParentOperation()->get(),
1943                              mlirOpResultGetOwner(self.get())) &&
1944           "expected the owner of the value in Python to match that in the IR");
1945       return self.getParentOperation().getObject();
1946     });
1947     c.def_property_readonly("result_number", [](PyOpResult &self) {
1948       return mlirOpResultGetResultNumber(self.get());
1949     });
1950   }
1951 };
1952 
1953 /// Returns the list of types of the values held by container.
1954 template <typename Container>
1955 static std::vector<PyType> getValueTypes(Container &container,
1956                                          PyMlirContextRef &context) {
1957   std::vector<PyType> result;
1958   result.reserve(container.getNumElements());
1959   for (int i = 0, e = container.getNumElements(); i < e; ++i) {
1960     result.push_back(
1961         PyType(context, mlirValueGetType(container.getElement(i).get())));
1962   }
1963   return result;
1964 }
1965 
1966 /// A list of block arguments. Internally, these are stored as consecutive
1967 /// elements, random access is cheap. The argument list is associated with the
1968 /// operation that contains the block (detached blocks are not allowed in
1969 /// Python bindings) and extends its lifetime.
1970 class PyBlockArgumentList
1971     : public Sliceable<PyBlockArgumentList, PyBlockArgument> {
1972 public:
1973   static constexpr const char *pyClassName = "BlockArgumentList";
1974 
1975   PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
1976                       intptr_t startIndex = 0, intptr_t length = -1,
1977                       intptr_t step = 1)
1978       : Sliceable(startIndex,
1979                   length == -1 ? mlirBlockGetNumArguments(block) : length,
1980                   step),
1981         operation(std::move(operation)), block(block) {}
1982 
1983   /// Returns the number of arguments in the list.
1984   intptr_t getNumElements() {
1985     operation->checkValid();
1986     return mlirBlockGetNumArguments(block);
1987   }
1988 
1989   /// Returns `pos`-the element in the list. Asserts on out-of-bounds.
1990   PyBlockArgument getElement(intptr_t pos) {
1991     MlirValue argument = mlirBlockGetArgument(block, pos);
1992     return PyBlockArgument(operation, argument);
1993   }
1994 
1995   /// Returns a sublist of this list.
1996   PyBlockArgumentList slice(intptr_t startIndex, intptr_t length,
1997                             intptr_t step) {
1998     return PyBlockArgumentList(operation, block, startIndex, length, step);
1999   }
2000 
2001   static void bindDerived(ClassTy &c) {
2002     c.def_property_readonly("types", [](PyBlockArgumentList &self) {
2003       return getValueTypes(self, self.operation->getContext());
2004     });
2005   }
2006 
2007 private:
2008   PyOperationRef operation;
2009   MlirBlock block;
2010 };
2011 
2012 /// A list of operation operands. Internally, these are stored as consecutive
2013 /// elements, random access is cheap. The result list is associated with the
2014 /// operation whose results these are, and extends the lifetime of this
2015 /// operation.
2016 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
2017 public:
2018   static constexpr const char *pyClassName = "OpOperandList";
2019 
2020   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
2021                   intptr_t length = -1, intptr_t step = 1)
2022       : Sliceable(startIndex,
2023                   length == -1 ? mlirOperationGetNumOperands(operation->get())
2024                                : length,
2025                   step),
2026         operation(operation) {}
2027 
2028   intptr_t getNumElements() {
2029     operation->checkValid();
2030     return mlirOperationGetNumOperands(operation->get());
2031   }
2032 
2033   PyValue getElement(intptr_t pos) {
2034     MlirValue operand = mlirOperationGetOperand(operation->get(), pos);
2035     MlirOperation owner;
2036     if (mlirValueIsAOpResult(operand))
2037       owner = mlirOpResultGetOwner(operand);
2038     else if (mlirValueIsABlockArgument(operand))
2039       owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand));
2040     else
2041       assert(false && "Value must be an block arg or op result.");
2042     PyOperationRef pyOwner =
2043         PyOperation::forOperation(operation->getContext(), owner);
2044     return PyValue(pyOwner, operand);
2045   }
2046 
2047   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2048     return PyOpOperandList(operation, startIndex, length, step);
2049   }
2050 
2051   void dunderSetItem(intptr_t index, PyValue value) {
2052     index = wrapIndex(index);
2053     mlirOperationSetOperand(operation->get(), index, value.get());
2054   }
2055 
2056   static void bindDerived(ClassTy &c) {
2057     c.def("__setitem__", &PyOpOperandList::dunderSetItem);
2058   }
2059 
2060 private:
2061   PyOperationRef operation;
2062 };
2063 
2064 /// A list of operation results. Internally, these are stored as consecutive
2065 /// elements, random access is cheap. The result list is associated with the
2066 /// operation whose results these are, and extends the lifetime of this
2067 /// operation.
2068 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
2069 public:
2070   static constexpr const char *pyClassName = "OpResultList";
2071 
2072   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
2073                  intptr_t length = -1, intptr_t step = 1)
2074       : Sliceable(startIndex,
2075                   length == -1 ? mlirOperationGetNumResults(operation->get())
2076                                : length,
2077                   step),
2078         operation(operation) {}
2079 
2080   intptr_t getNumElements() {
2081     operation->checkValid();
2082     return mlirOperationGetNumResults(operation->get());
2083   }
2084 
2085   PyOpResult getElement(intptr_t index) {
2086     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
2087     return PyOpResult(value);
2088   }
2089 
2090   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
2091     return PyOpResultList(operation, startIndex, length, step);
2092   }
2093 
2094   static void bindDerived(ClassTy &c) {
2095     c.def_property_readonly("types", [](PyOpResultList &self) {
2096       return getValueTypes(self, self.operation->getContext());
2097     });
2098   }
2099 
2100 private:
2101   PyOperationRef operation;
2102 };
2103 
2104 /// A list of operation attributes. Can be indexed by name, producing
2105 /// attributes, or by index, producing named attributes.
2106 class PyOpAttributeMap {
2107 public:
2108   PyOpAttributeMap(PyOperationRef operation)
2109       : operation(std::move(operation)) {}
2110 
2111   PyAttribute dunderGetItemNamed(const std::string &name) {
2112     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
2113                                                          toMlirStringRef(name));
2114     if (mlirAttributeIsNull(attr)) {
2115       throw SetPyError(PyExc_KeyError,
2116                        "attempt to access a non-existent attribute");
2117     }
2118     return PyAttribute(operation->getContext(), attr);
2119   }
2120 
2121   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
2122     if (index < 0 || index >= dunderLen()) {
2123       throw SetPyError(PyExc_IndexError,
2124                        "attempt to access out of bounds attribute");
2125     }
2126     MlirNamedAttribute namedAttr =
2127         mlirOperationGetAttribute(operation->get(), index);
2128     return PyNamedAttribute(
2129         namedAttr.attribute,
2130         std::string(mlirIdentifierStr(namedAttr.name).data,
2131                     mlirIdentifierStr(namedAttr.name).length));
2132   }
2133 
2134   void dunderSetItem(const std::string &name, const PyAttribute &attr) {
2135     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
2136                                     attr);
2137   }
2138 
2139   void dunderDelItem(const std::string &name) {
2140     int removed = mlirOperationRemoveAttributeByName(operation->get(),
2141                                                      toMlirStringRef(name));
2142     if (!removed)
2143       throw SetPyError(PyExc_KeyError,
2144                        "attempt to delete a non-existent attribute");
2145   }
2146 
2147   intptr_t dunderLen() {
2148     return mlirOperationGetNumAttributes(operation->get());
2149   }
2150 
2151   bool dunderContains(const std::string &name) {
2152     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
2153         operation->get(), toMlirStringRef(name)));
2154   }
2155 
2156   static void bind(py::module &m) {
2157     py::class_<PyOpAttributeMap>(m, "OpAttributeMap", py::module_local())
2158         .def("__contains__", &PyOpAttributeMap::dunderContains)
2159         .def("__len__", &PyOpAttributeMap::dunderLen)
2160         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
2161         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
2162         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
2163         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
2164   }
2165 
2166 private:
2167   PyOperationRef operation;
2168 };
2169 
2170 } // namespace
2171 
2172 //------------------------------------------------------------------------------
2173 // Populates the core exports of the 'ir' submodule.
2174 //------------------------------------------------------------------------------
2175 
2176 void mlir::python::populateIRCore(py::module &m) {
2177   //----------------------------------------------------------------------------
2178   // Enums.
2179   //----------------------------------------------------------------------------
2180   py::enum_<MlirDiagnosticSeverity>(m, "DiagnosticSeverity", py::module_local())
2181       .value("ERROR", MlirDiagnosticError)
2182       .value("WARNING", MlirDiagnosticWarning)
2183       .value("NOTE", MlirDiagnosticNote)
2184       .value("REMARK", MlirDiagnosticRemark);
2185 
2186   //----------------------------------------------------------------------------
2187   // Mapping of Diagnostics.
2188   //----------------------------------------------------------------------------
2189   py::class_<PyDiagnostic>(m, "Diagnostic", py::module_local())
2190       .def_property_readonly("severity", &PyDiagnostic::getSeverity)
2191       .def_property_readonly("location", &PyDiagnostic::getLocation)
2192       .def_property_readonly("message", &PyDiagnostic::getMessage)
2193       .def_property_readonly("notes", &PyDiagnostic::getNotes)
2194       .def("__str__", [](PyDiagnostic &self) -> py::str {
2195         if (!self.isValid())
2196           return "<Invalid Diagnostic>";
2197         return self.getMessage();
2198       });
2199 
2200   py::class_<PyDiagnosticHandler>(m, "DiagnosticHandler", py::module_local())
2201       .def("detach", &PyDiagnosticHandler::detach)
2202       .def_property_readonly("attached", &PyDiagnosticHandler::isAttached)
2203       .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError)
2204       .def("__enter__", &PyDiagnosticHandler::contextEnter)
2205       .def("__exit__", &PyDiagnosticHandler::contextExit);
2206 
2207   //----------------------------------------------------------------------------
2208   // Mapping of MlirContext.
2209   //----------------------------------------------------------------------------
2210   py::class_<PyMlirContext>(m, "Context", py::module_local())
2211       .def(py::init<>(&PyMlirContext::createNewContextForInit))
2212       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2213       .def("_get_context_again",
2214            [](PyMlirContext &self) {
2215              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2216              return ref.releaseObject();
2217            })
2218       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2219       .def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
2220       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2221       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2222                              &PyMlirContext::getCapsule)
2223       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2224       .def("__enter__", &PyMlirContext::contextEnter)
2225       .def("__exit__", &PyMlirContext::contextExit)
2226       .def_property_readonly_static(
2227           "current",
2228           [](py::object & /*class*/) {
2229             auto *context = PyThreadContextEntry::getDefaultContext();
2230             if (!context)
2231               throw SetPyError(PyExc_ValueError, "No current Context");
2232             return context;
2233           },
2234           "Gets the Context bound to the current thread or raises ValueError")
2235       .def_property_readonly(
2236           "dialects",
2237           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2238           "Gets a container for accessing dialects by name")
2239       .def_property_readonly(
2240           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2241           "Alias for 'dialect'")
2242       .def(
2243           "get_dialect_descriptor",
2244           [=](PyMlirContext &self, std::string &name) {
2245             MlirDialect dialect = mlirContextGetOrLoadDialect(
2246                 self.get(), {name.data(), name.size()});
2247             if (mlirDialectIsNull(dialect)) {
2248               throw SetPyError(PyExc_ValueError,
2249                                Twine("Dialect '") + name + "' not found");
2250             }
2251             return PyDialectDescriptor(self.getRef(), dialect);
2252           },
2253           py::arg("dialect_name"),
2254           "Gets or loads a dialect by name, returning its descriptor object")
2255       .def_property(
2256           "allow_unregistered_dialects",
2257           [](PyMlirContext &self) -> bool {
2258             return mlirContextGetAllowUnregisteredDialects(self.get());
2259           },
2260           [](PyMlirContext &self, bool value) {
2261             mlirContextSetAllowUnregisteredDialects(self.get(), value);
2262           })
2263       .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler,
2264            py::arg("callback"),
2265            "Attaches a diagnostic handler that will receive callbacks")
2266       .def(
2267           "enable_multithreading",
2268           [](PyMlirContext &self, bool enable) {
2269             mlirContextEnableMultithreading(self.get(), enable);
2270           },
2271           py::arg("enable"))
2272       .def(
2273           "is_registered_operation",
2274           [](PyMlirContext &self, std::string &name) {
2275             return mlirContextIsRegisteredOperation(
2276                 self.get(), MlirStringRef{name.data(), name.size()});
2277           },
2278           py::arg("operation_name"));
2279 
2280   //----------------------------------------------------------------------------
2281   // Mapping of PyDialectDescriptor
2282   //----------------------------------------------------------------------------
2283   py::class_<PyDialectDescriptor>(m, "DialectDescriptor", py::module_local())
2284       .def_property_readonly("namespace",
2285                              [](PyDialectDescriptor &self) {
2286                                MlirStringRef ns =
2287                                    mlirDialectGetNamespace(self.get());
2288                                return py::str(ns.data, ns.length);
2289                              })
2290       .def("__repr__", [](PyDialectDescriptor &self) {
2291         MlirStringRef ns = mlirDialectGetNamespace(self.get());
2292         std::string repr("<DialectDescriptor ");
2293         repr.append(ns.data, ns.length);
2294         repr.append(">");
2295         return repr;
2296       });
2297 
2298   //----------------------------------------------------------------------------
2299   // Mapping of PyDialects
2300   //----------------------------------------------------------------------------
2301   py::class_<PyDialects>(m, "Dialects", py::module_local())
2302       .def("__getitem__",
2303            [=](PyDialects &self, std::string keyName) {
2304              MlirDialect dialect =
2305                  self.getDialectForKey(keyName, /*attrError=*/false);
2306              py::object descriptor =
2307                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
2308              return createCustomDialectWrapper(keyName, std::move(descriptor));
2309            })
2310       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2311         MlirDialect dialect =
2312             self.getDialectForKey(attrName, /*attrError=*/true);
2313         py::object descriptor =
2314             py::cast(PyDialectDescriptor{self.getContext(), dialect});
2315         return createCustomDialectWrapper(attrName, std::move(descriptor));
2316       });
2317 
2318   //----------------------------------------------------------------------------
2319   // Mapping of PyDialect
2320   //----------------------------------------------------------------------------
2321   py::class_<PyDialect>(m, "Dialect", py::module_local())
2322       .def(py::init<py::object>(), py::arg("descriptor"))
2323       .def_property_readonly(
2324           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2325       .def("__repr__", [](py::object self) {
2326         auto clazz = self.attr("__class__");
2327         return py::str("<Dialect ") +
2328                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2329                clazz.attr("__module__") + py::str(".") +
2330                clazz.attr("__name__") + py::str(")>");
2331       });
2332 
2333   //----------------------------------------------------------------------------
2334   // Mapping of Location
2335   //----------------------------------------------------------------------------
2336   py::class_<PyLocation>(m, "Location", py::module_local())
2337       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2338       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2339       .def("__enter__", &PyLocation::contextEnter)
2340       .def("__exit__", &PyLocation::contextExit)
2341       .def("__eq__",
2342            [](PyLocation &self, PyLocation &other) -> bool {
2343              return mlirLocationEqual(self, other);
2344            })
2345       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2346       .def_property_readonly_static(
2347           "current",
2348           [](py::object & /*class*/) {
2349             auto *loc = PyThreadContextEntry::getDefaultLocation();
2350             if (!loc)
2351               throw SetPyError(PyExc_ValueError, "No current Location");
2352             return loc;
2353           },
2354           "Gets the Location bound to the current thread or raises ValueError")
2355       .def_static(
2356           "unknown",
2357           [](DefaultingPyMlirContext context) {
2358             return PyLocation(context->getRef(),
2359                               mlirLocationUnknownGet(context->get()));
2360           },
2361           py::arg("context") = py::none(),
2362           "Gets a Location representing an unknown location")
2363       .def_static(
2364           "callsite",
2365           [](PyLocation callee, const std::vector<PyLocation> &frames,
2366              DefaultingPyMlirContext context) {
2367             if (frames.empty())
2368               throw py::value_error("No caller frames provided");
2369             MlirLocation caller = frames.back().get();
2370             for (const PyLocation &frame :
2371                  llvm::reverse(llvm::makeArrayRef(frames).drop_back()))
2372               caller = mlirLocationCallSiteGet(frame.get(), caller);
2373             return PyLocation(context->getRef(),
2374                               mlirLocationCallSiteGet(callee.get(), caller));
2375           },
2376           py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(),
2377           kContextGetCallSiteLocationDocstring)
2378       .def_static(
2379           "file",
2380           [](std::string filename, int line, int col,
2381              DefaultingPyMlirContext context) {
2382             return PyLocation(
2383                 context->getRef(),
2384                 mlirLocationFileLineColGet(
2385                     context->get(), toMlirStringRef(filename), line, col));
2386           },
2387           py::arg("filename"), py::arg("line"), py::arg("col"),
2388           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2389       .def_static(
2390           "fused",
2391           [](const std::vector<PyLocation> &pyLocations,
2392              llvm::Optional<PyAttribute> metadata,
2393              DefaultingPyMlirContext context) {
2394             llvm::SmallVector<MlirLocation, 4> locations;
2395             locations.reserve(pyLocations.size());
2396             for (auto &pyLocation : pyLocations)
2397               locations.push_back(pyLocation.get());
2398             MlirLocation location = mlirLocationFusedGet(
2399                 context->get(), locations.size(), locations.data(),
2400                 metadata ? metadata->get() : MlirAttribute{0});
2401             return PyLocation(context->getRef(), location);
2402           },
2403           py::arg("locations"), py::arg("metadata") = py::none(),
2404           py::arg("context") = py::none(), kContextGetFusedLocationDocstring)
2405       .def_static(
2406           "name",
2407           [](std::string name, llvm::Optional<PyLocation> childLoc,
2408              DefaultingPyMlirContext context) {
2409             return PyLocation(
2410                 context->getRef(),
2411                 mlirLocationNameGet(
2412                     context->get(), toMlirStringRef(name),
2413                     childLoc ? childLoc->get()
2414                              : mlirLocationUnknownGet(context->get())));
2415           },
2416           py::arg("name"), py::arg("childLoc") = py::none(),
2417           py::arg("context") = py::none(), kContextGetNameLocationDocString)
2418       .def_property_readonly(
2419           "context",
2420           [](PyLocation &self) { return self.getContext().getObject(); },
2421           "Context that owns the Location")
2422       .def(
2423           "emit_error",
2424           [](PyLocation &self, std::string message) {
2425             mlirEmitError(self, message.c_str());
2426           },
2427           py::arg("message"), "Emits an error at this location")
2428       .def("__repr__", [](PyLocation &self) {
2429         PyPrintAccumulator printAccum;
2430         mlirLocationPrint(self, printAccum.getCallback(),
2431                           printAccum.getUserData());
2432         return printAccum.join();
2433       });
2434 
2435   //----------------------------------------------------------------------------
2436   // Mapping of Module
2437   //----------------------------------------------------------------------------
2438   py::class_<PyModule>(m, "Module", py::module_local())
2439       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2440       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2441       .def_static(
2442           "parse",
2443           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2444             MlirModule module = mlirModuleCreateParse(
2445                 context->get(), toMlirStringRef(moduleAsm));
2446             // TODO: Rework error reporting once diagnostic engine is exposed
2447             // in C API.
2448             if (mlirModuleIsNull(module)) {
2449               throw SetPyError(
2450                   PyExc_ValueError,
2451                   "Unable to parse module assembly (see diagnostics)");
2452             }
2453             return PyModule::forModule(module).releaseObject();
2454           },
2455           py::arg("asm"), py::arg("context") = py::none(),
2456           kModuleParseDocstring)
2457       .def_static(
2458           "create",
2459           [](DefaultingPyLocation loc) {
2460             MlirModule module = mlirModuleCreateEmpty(loc);
2461             return PyModule::forModule(module).releaseObject();
2462           },
2463           py::arg("loc") = py::none(), "Creates an empty module")
2464       .def_property_readonly(
2465           "context",
2466           [](PyModule &self) { return self.getContext().getObject(); },
2467           "Context that created the Module")
2468       .def_property_readonly(
2469           "operation",
2470           [](PyModule &self) {
2471             return PyOperation::forOperation(self.getContext(),
2472                                              mlirModuleGetOperation(self.get()),
2473                                              self.getRef().releaseObject())
2474                 .releaseObject();
2475           },
2476           "Accesses the module as an operation")
2477       .def_property_readonly(
2478           "body",
2479           [](PyModule &self) {
2480             PyOperationRef moduleOp = PyOperation::forOperation(
2481                 self.getContext(), mlirModuleGetOperation(self.get()),
2482                 self.getRef().releaseObject());
2483             PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get()));
2484             return returnBlock;
2485           },
2486           "Return the block for this module")
2487       .def(
2488           "dump",
2489           [](PyModule &self) {
2490             mlirOperationDump(mlirModuleGetOperation(self.get()));
2491           },
2492           kDumpDocstring)
2493       .def(
2494           "__str__",
2495           [](py::object self) {
2496             // Defer to the operation's __str__.
2497             return self.attr("operation").attr("__str__")();
2498           },
2499           kOperationStrDunderDocstring);
2500 
2501   //----------------------------------------------------------------------------
2502   // Mapping of Operation.
2503   //----------------------------------------------------------------------------
  // Python `_OperationBase` class: API shared by the `Operation` and `OpView`
  // classes (both are bound with PyOperationBase as their base). Each member
  // reaches the underlying PyOperation through getOperation().
  py::class_<PyOperationBase>(m, "_OperationBase", py::module_local())
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
                             [](PyOperationBase &self) {
                               return self.getOperation().getCapsule();
                             })
      // Equality compares the addresses of the underlying PyOperation objects;
      // any non-operation argument falls through to the second overload and
      // compares unequal. Registration order matters: the typed overload is
      // registered first so pybind11 tries it first.
      .def("__eq__",
           [](PyOperationBase &self, PyOperationBase &other) {
             return &self.getOperation() == &other.getOperation();
           })
      .def("__eq__",
           [](PyOperationBase &self, py::object other) { return false; })
      // Hashes the same PyOperation address that __eq__ compares.
      .def("__hash__",
           [](PyOperationBase &self) {
             return static_cast<size_t>(llvm::hash_value(&self.getOperation()));
           })
      // Container views over the operation's attributes, operands and regions.
      .def_property_readonly("attributes",
                             [](PyOperationBase &self) {
                               return PyOpAttributeMap(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("operands",
                             [](PyOperationBase &self) {
                               return PyOpOperandList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("regions",
                             [](PyOperationBase &self) {
                               return PyRegionList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly(
          "results",
          [](PyOperationBase &self) {
            return PyOpResultList(self.getOperation().getRef());
          },
          "Returns the list of Operation results.")
      // Raises ValueError (naming the op and its result count) unless the
      // operation has exactly one result.
      .def_property_readonly(
          "result",
          [](PyOperationBase &self) {
            auto &operation = self.getOperation();
            auto numResults = mlirOperationGetNumResults(operation);
            if (numResults != 1) {
              auto name = mlirIdentifierStr(mlirOperationGetName(operation));
              throw SetPyError(
                  PyExc_ValueError,
                  Twine("Cannot call .result on operation ") +
                      StringRef(name.data, name.length) + " which has " +
                      Twine(numResults) +
                      " results (it is only valid for operations with a "
                      "single result)");
            }
            return PyOpResult(operation.getRef(),
                              mlirOperationGetResult(operation, 0));
          },
          "Shortcut to get an op result if it has only one (throws an error "
          "otherwise).")
      .def_property_readonly(
          "location",
          [](PyOperationBase &self) {
            PyOperation &operation = self.getOperation();
            return PyLocation(operation.getContext(),
                              mlirOperationGetLocation(operation.get()));
          },
          "Returns the source location the operation was defined or derived "
          "from.")
      // str() uses getAsm with every option off: textual output, no element
      // limit, no debug info, non-generic form, no local scope, and without
      // the assume-verified flag.
      .def(
          "__str__",
          [](PyOperationBase &self) {
            return self.getAsm(/*binary=*/false,
                               /*largeElementsLimit=*/llvm::None,
                               /*enableDebugInfo=*/false,
                               /*prettyDebugInfo=*/false,
                               /*printGenericOpForm=*/false,
                               /*useLocalScope=*/false,
                               /*assumeVerified=*/false);
          },
          "Returns the assembly form of the operation.")
      // The py::arg lists of `print` and `get_asm` are positional mirrors of
      // the C++ parameters of PyOperationBase::print / getAsm; keep them in
      // the same order if either signature changes.
      .def("print", &PyOperationBase::print,
           // Careful: Lots of arguments must match up with print method.
           py::arg("file") = py::none(), py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false,
           py::arg("assume_verified") = false, kOperationPrintDocstring)
      .def("get_asm", &PyOperationBase::getAsm,
           // Careful: Lots of arguments must match up with get_asm method.
           py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false,
           py::arg("assume_verified") = false, kOperationGetAsmDocstring)
      // Returns the result of mlirOperationVerify directly.
      .def(
          "verify",
          [](PyOperationBase &self) {
            return mlirOperationVerify(self.getOperation());
          },
          "Verify the operation and return true if it passes, false if it "
          "fails.")
      .def("move_after", &PyOperationBase::moveAfter, py::arg("other"),
           "Puts self immediately after the other operation in its parent "
           "block.")
      .def("move_before", &PyOperationBase::moveBefore, py::arg("other"),
           "Puts self immediately before the other operation in its parent "
           "block.")
      // Validates the operation, raises ValueError if it is not attached,
      // otherwise detaches it and returns the result of createOpView().
      .def(
          "detach_from_parent",
          [](PyOperationBase &self) {
            PyOperation &operation = self.getOperation();
            operation.checkValid();
            if (!operation.isAttached())
              throw py::value_error("Detached operation has no parent.");

            operation.detachFromParent();
            return operation.createOpView();
          },
          "Detaches the operation from its parent block.");
2624 
2625   py::class_<PyOperation, PyOperationBase>(m, "Operation", py::module_local())
2626       .def_static("create", &PyOperation::create, py::arg("name"),
2627                   py::arg("results") = py::none(),
2628                   py::arg("operands") = py::none(),
2629                   py::arg("attributes") = py::none(),
2630                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2631                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2632                   kOperationCreateDocstring)
2633       .def_property_readonly("parent",
2634                              [](PyOperation &self) -> py::object {
2635                                auto parent = self.getParentOperation();
2636                                if (parent)
2637                                  return parent->getObject();
2638                                return py::none();
2639                              })
2640       .def("erase", &PyOperation::erase)
2641       .def("clone", &PyOperation::clone, py::arg("ip") = py::none())
2642       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2643                              &PyOperation::getCapsule)
2644       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2645       .def_property_readonly("name",
2646                              [](PyOperation &self) {
2647                                self.checkValid();
2648                                MlirOperation operation = self.get();
2649                                MlirStringRef name = mlirIdentifierStr(
2650                                    mlirOperationGetName(operation));
2651                                return py::str(name.data, name.length);
2652                              })
2653       .def_property_readonly(
2654           "context",
2655           [](PyOperation &self) {
2656             self.checkValid();
2657             return self.getContext().getObject();
2658           },
2659           "Context that owns the Operation")
2660       .def_property_readonly("opview", &PyOperation::createOpView);
2661 
2662   auto opViewClass =
2663       py::class_<PyOpView, PyOperationBase>(m, "OpView", py::module_local())
2664           .def(py::init<py::object>(), py::arg("operation"))
2665           .def_property_readonly("operation", &PyOpView::getOperationObject)
2666           .def_property_readonly(
2667               "context",
2668               [](PyOpView &self) {
2669                 return self.getOperation().getContext().getObject();
2670               },
2671               "Context that owns the Operation")
2672           .def("__str__", [](PyOpView &self) {
2673             return py::str(self.getOperationObject());
2674           });
2675   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2676   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2677   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2678   opViewClass.attr("build_generic") = classmethod(
2679       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2680       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2681       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2682       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2683       "Builds a specific, generated OpView based on class level attributes.");
2684 
2685   //----------------------------------------------------------------------------
2686   // Mapping of PyRegion.
2687   //----------------------------------------------------------------------------
2688   py::class_<PyRegion>(m, "Region", py::module_local())
2689       .def_property_readonly(
2690           "blocks",
2691           [](PyRegion &self) {
2692             return PyBlockList(self.getParentOperation(), self.get());
2693           },
2694           "Returns a forward-optimized sequence of blocks.")
2695       .def_property_readonly(
2696           "owner",
2697           [](PyRegion &self) {
2698             return self.getParentOperation()->createOpView();
2699           },
2700           "Returns the operation owning this region.")
2701       .def(
2702           "__iter__",
2703           [](PyRegion &self) {
2704             self.checkValid();
2705             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2706             return PyBlockIterator(self.getParentOperation(), firstBlock);
2707           },
2708           "Iterates over blocks in the region.")
2709       .def("__eq__",
2710            [](PyRegion &self, PyRegion &other) {
2711              return self.get().ptr == other.get().ptr;
2712            })
2713       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2714 
2715   //----------------------------------------------------------------------------
2716   // Mapping of PyBlock.
2717   //----------------------------------------------------------------------------
2718   py::class_<PyBlock>(m, "Block", py::module_local())
2719       .def_property_readonly(
2720           "owner",
2721           [](PyBlock &self) {
2722             return self.getParentOperation()->createOpView();
2723           },
2724           "Returns the owning operation of this block.")
2725       .def_property_readonly(
2726           "region",
2727           [](PyBlock &self) {
2728             MlirRegion region = mlirBlockGetParentRegion(self.get());
2729             return PyRegion(self.getParentOperation(), region);
2730           },
2731           "Returns the owning region of this block.")
2732       .def_property_readonly(
2733           "arguments",
2734           [](PyBlock &self) {
2735             return PyBlockArgumentList(self.getParentOperation(), self.get());
2736           },
2737           "Returns a list of block arguments.")
2738       .def_property_readonly(
2739           "operations",
2740           [](PyBlock &self) {
2741             return PyOperationList(self.getParentOperation(), self.get());
2742           },
2743           "Returns a forward-optimized sequence of operations.")
2744       .def_static(
2745           "create_at_start",
2746           [](PyRegion &parent, py::list pyArgTypes) {
2747             parent.checkValid();
2748             llvm::SmallVector<MlirType, 4> argTypes;
2749             llvm::SmallVector<MlirLocation, 4> argLocs;
2750             argTypes.reserve(pyArgTypes.size());
2751             argLocs.reserve(pyArgTypes.size());
2752             for (auto &pyArg : pyArgTypes) {
2753               argTypes.push_back(pyArg.cast<PyType &>());
2754               // TODO: Pass in a proper location here.
2755               argLocs.push_back(
2756                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2757             }
2758 
2759             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2760                                               argLocs.data());
2761             mlirRegionInsertOwnedBlock(parent, 0, block);
2762             return PyBlock(parent.getParentOperation(), block);
2763           },
2764           py::arg("parent"), py::arg("arg_types") = py::list(),
2765           "Creates and returns a new Block at the beginning of the given "
2766           "region (with given argument types).")
2767       .def(
2768           "append_to",
2769           [](PyBlock &self, PyRegion &region) {
2770             MlirBlock b = self.get();
2771             if (!mlirRegionIsNull(mlirBlockGetParentRegion(b)))
2772               mlirBlockDetach(b);
2773             mlirRegionAppendOwnedBlock(region.get(), b);
2774           },
2775           "Append this block to a region, transferring ownership if necessary")
2776       .def(
2777           "create_before",
2778           [](PyBlock &self, py::args pyArgTypes) {
2779             self.checkValid();
2780             llvm::SmallVector<MlirType, 4> argTypes;
2781             llvm::SmallVector<MlirLocation, 4> argLocs;
2782             argTypes.reserve(pyArgTypes.size());
2783             argLocs.reserve(pyArgTypes.size());
2784             for (auto &pyArg : pyArgTypes) {
2785               argTypes.push_back(pyArg.cast<PyType &>());
2786               // TODO: Pass in a proper location here.
2787               argLocs.push_back(
2788                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2789             }
2790 
2791             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2792                                               argLocs.data());
2793             MlirRegion region = mlirBlockGetParentRegion(self.get());
2794             mlirRegionInsertOwnedBlockBefore(region, self.get(), block);
2795             return PyBlock(self.getParentOperation(), block);
2796           },
2797           "Creates and returns a new Block before this block "
2798           "(with given argument types).")
2799       .def(
2800           "create_after",
2801           [](PyBlock &self, py::args pyArgTypes) {
2802             self.checkValid();
2803             llvm::SmallVector<MlirType, 4> argTypes;
2804             llvm::SmallVector<MlirLocation, 4> argLocs;
2805             argTypes.reserve(pyArgTypes.size());
2806             argLocs.reserve(pyArgTypes.size());
2807             for (auto &pyArg : pyArgTypes) {
2808               argTypes.push_back(pyArg.cast<PyType &>());
2809 
2810               // TODO: Pass in a proper location here.
2811               argLocs.push_back(
2812                   mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back())));
2813             }
2814             MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(),
2815                                               argLocs.data());
2816             MlirRegion region = mlirBlockGetParentRegion(self.get());
2817             mlirRegionInsertOwnedBlockAfter(region, self.get(), block);
2818             return PyBlock(self.getParentOperation(), block);
2819           },
2820           "Creates and returns a new Block after this block "
2821           "(with given argument types).")
2822       .def(
2823           "__iter__",
2824           [](PyBlock &self) {
2825             self.checkValid();
2826             MlirOperation firstOperation =
2827                 mlirBlockGetFirstOperation(self.get());
2828             return PyOperationIterator(self.getParentOperation(),
2829                                        firstOperation);
2830           },
2831           "Iterates over operations in the block.")
2832       .def("__eq__",
2833            [](PyBlock &self, PyBlock &other) {
2834              return self.get().ptr == other.get().ptr;
2835            })
2836       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2837       .def(
2838           "__str__",
2839           [](PyBlock &self) {
2840             self.checkValid();
2841             PyPrintAccumulator printAccum;
2842             mlirBlockPrint(self.get(), printAccum.getCallback(),
2843                            printAccum.getUserData());
2844             return printAccum.join();
2845           },
2846           "Returns the assembly form of the block.")
2847       .def(
2848           "append",
2849           [](PyBlock &self, PyOperationBase &operation) {
2850             if (operation.getOperation().isAttached())
2851               operation.getOperation().detachFromParent();
2852 
2853             MlirOperation mlirOperation = operation.getOperation().get();
2854             mlirBlockAppendOwnedOperation(self.get(), mlirOperation);
2855             operation.getOperation().setAttached(
2856                 self.getParentOperation().getObject());
2857           },
2858           py::arg("operation"),
2859           "Appends an operation to this block. If the operation is currently "
2860           "in another block, it will be moved.");
2861 
2862   //----------------------------------------------------------------------------
2863   // Mapping of PyInsertionPoint.
2864   //----------------------------------------------------------------------------
2865 
2866   py::class_<PyInsertionPoint>(m, "InsertionPoint", py::module_local())
2867       .def(py::init<PyBlock &>(), py::arg("block"),
2868            "Inserts after the last operation but still inside the block.")
2869       .def("__enter__", &PyInsertionPoint::contextEnter)
2870       .def("__exit__", &PyInsertionPoint::contextExit)
2871       .def_property_readonly_static(
2872           "current",
2873           [](py::object & /*class*/) {
2874             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2875             if (!ip)
2876               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2877             return ip;
2878           },
2879           "Gets the InsertionPoint bound to the current thread or raises "
2880           "ValueError if none has been set")
2881       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2882            "Inserts before a referenced operation.")
2883       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2884                   py::arg("block"), "Inserts at the beginning of the block.")
2885       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2886                   py::arg("block"), "Inserts before the block terminator.")
2887       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2888            "Inserts an operation.")
2889       .def_property_readonly(
2890           "block", [](PyInsertionPoint &self) { return self.getBlock(); },
2891           "Returns the block that this InsertionPoint points to.");
2892 
2893   //----------------------------------------------------------------------------
2894   // Mapping of PyAttribute.
2895   //----------------------------------------------------------------------------
2896   py::class_<PyAttribute>(m, "Attribute", py::module_local())
2897       // Delegate to the PyAttribute copy constructor, which will also lifetime
2898       // extend the backing context which owns the MlirAttribute.
2899       .def(py::init<PyAttribute &>(), py::arg("cast_from_type"),
2900            "Casts the passed attribute to the generic Attribute")
2901       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2902                              &PyAttribute::getCapsule)
2903       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2904       .def_static(
2905           "parse",
2906           [](std::string attrSpec, DefaultingPyMlirContext context) {
2907             MlirAttribute type = mlirAttributeParseGet(
2908                 context->get(), toMlirStringRef(attrSpec));
2909             // TODO: Rework error reporting once diagnostic engine is exposed
2910             // in C API.
2911             if (mlirAttributeIsNull(type)) {
2912               throw SetPyError(PyExc_ValueError,
2913                                Twine("Unable to parse attribute: '") +
2914                                    attrSpec + "'");
2915             }
2916             return PyAttribute(context->getRef(), type);
2917           },
2918           py::arg("asm"), py::arg("context") = py::none(),
2919           "Parses an attribute from an assembly form")
2920       .def_property_readonly(
2921           "context",
2922           [](PyAttribute &self) { return self.getContext().getObject(); },
2923           "Context that owns the Attribute")
2924       .def_property_readonly("type",
2925                              [](PyAttribute &self) {
2926                                return PyType(self.getContext()->getRef(),
2927                                              mlirAttributeGetType(self));
2928                              })
2929       .def(
2930           "get_named",
2931           [](PyAttribute &self, std::string name) {
2932             return PyNamedAttribute(self, std::move(name));
2933           },
2934           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2935       .def("__eq__",
2936            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2937       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2938       .def("__hash__",
2939            [](PyAttribute &self) {
2940              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
2941            })
2942       .def(
2943           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2944           kDumpDocstring)
2945       .def(
2946           "__str__",
2947           [](PyAttribute &self) {
2948             PyPrintAccumulator printAccum;
2949             mlirAttributePrint(self, printAccum.getCallback(),
2950                                printAccum.getUserData());
2951             return printAccum.join();
2952           },
2953           "Returns the assembly form of the Attribute.")
2954       .def("__repr__", [](PyAttribute &self) {
2955         // Generally, assembly formats are not printed for __repr__ because
2956         // this can cause exceptionally long debug output and exceptions.
2957         // However, attribute values are generally considered useful and are
2958         // printed. This may need to be re-evaluated if debug dumps end up
2959         // being excessive.
2960         PyPrintAccumulator printAccum;
2961         printAccum.parts.append("Attribute(");
2962         mlirAttributePrint(self, printAccum.getCallback(),
2963                            printAccum.getUserData());
2964         printAccum.parts.append(")");
2965         return printAccum.join();
2966       });
2967 
2968   //----------------------------------------------------------------------------
2969   // Mapping of PyNamedAttribute
2970   //----------------------------------------------------------------------------
2971   py::class_<PyNamedAttribute>(m, "NamedAttribute", py::module_local())
2972       .def("__repr__",
2973            [](PyNamedAttribute &self) {
2974              PyPrintAccumulator printAccum;
2975              printAccum.parts.append("NamedAttribute(");
2976              printAccum.parts.append(
2977                  py::str(mlirIdentifierStr(self.namedAttr.name).data,
2978                          mlirIdentifierStr(self.namedAttr.name).length));
2979              printAccum.parts.append("=");
2980              mlirAttributePrint(self.namedAttr.attribute,
2981                                 printAccum.getCallback(),
2982                                 printAccum.getUserData());
2983              printAccum.parts.append(")");
2984              return printAccum.join();
2985            })
2986       .def_property_readonly(
2987           "name",
2988           [](PyNamedAttribute &self) {
2989             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2990                            mlirIdentifierStr(self.namedAttr.name).length);
2991           },
2992           "The name of the NamedAttribute binding")
2993       .def_property_readonly(
2994           "attr",
2995           [](PyNamedAttribute &self) {
2996             // TODO: When named attribute is removed/refactored, also remove
2997             // this constructor (it does an inefficient table lookup).
2998             auto contextRef = PyMlirContext::forContext(
2999                 mlirAttributeGetContext(self.namedAttr.attribute));
3000             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
3001           },
3002           py::keep_alive<0, 1>(),
3003           "The underlying generic attribute of the NamedAttribute binding");
3004 
3005   //----------------------------------------------------------------------------
3006   // Mapping of PyType.
3007   //----------------------------------------------------------------------------
3008   py::class_<PyType>(m, "Type", py::module_local())
3009       // Delegate to the PyType copy constructor, which will also lifetime
3010       // extend the backing context which owns the MlirType.
3011       .def(py::init<PyType &>(), py::arg("cast_from_type"),
3012            "Casts the passed type to the generic Type")
3013       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3014       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3015       .def_static(
3016           "parse",
3017           [](std::string typeSpec, DefaultingPyMlirContext context) {
3018             MlirType type =
3019                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3020             // TODO: Rework error reporting once diagnostic engine is exposed
3021             // in C API.
3022             if (mlirTypeIsNull(type)) {
3023               throw SetPyError(PyExc_ValueError,
3024                                Twine("Unable to parse type: '") + typeSpec +
3025                                    "'");
3026             }
3027             return PyType(context->getRef(), type);
3028           },
3029           py::arg("asm"), py::arg("context") = py::none(),
3030           kContextParseTypeDocstring)
3031       .def_property_readonly(
3032           "context", [](PyType &self) { return self.getContext().getObject(); },
3033           "Context that owns the Type")
3034       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3035       .def("__eq__", [](PyType &self, py::object &other) { return false; })
3036       .def("__hash__",
3037            [](PyType &self) {
3038              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3039            })
3040       .def(
3041           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3042       .def(
3043           "__str__",
3044           [](PyType &self) {
3045             PyPrintAccumulator printAccum;
3046             mlirTypePrint(self, printAccum.getCallback(),
3047                           printAccum.getUserData());
3048             return printAccum.join();
3049           },
3050           "Returns the assembly form of the type.")
3051       .def("__repr__", [](PyType &self) {
3052         // Generally, assembly formats are not printed for __repr__ because
3053         // this can cause exceptionally long debug output and exceptions.
3054         // However, types are an exception as they typically have compact
3055         // assembly forms and printing them is useful.
3056         PyPrintAccumulator printAccum;
3057         printAccum.parts.append("Type(");
3058         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
3059         printAccum.parts.append(")");
3060         return printAccum.join();
3061       });
3062 
3063   //----------------------------------------------------------------------------
3064   // Mapping of Value.
3065   //----------------------------------------------------------------------------
3066   py::class_<PyValue>(m, "Value", py::module_local())
3067       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule)
3068       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule)
3069       .def_property_readonly(
3070           "context",
3071           [](PyValue &self) { return self.getParentOperation()->getContext(); },
3072           "Context in which the value lives.")
3073       .def(
3074           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3075           kDumpDocstring)
3076       .def_property_readonly(
3077           "owner",
3078           [](PyValue &self) {
3079             assert(mlirOperationEqual(self.getParentOperation()->get(),
3080                                       mlirOpResultGetOwner(self.get())) &&
3081                    "expected the owner of the value in Python to match that in "
3082                    "the IR");
3083             return self.getParentOperation().getObject();
3084           })
3085       .def("__eq__",
3086            [](PyValue &self, PyValue &other) {
3087              return self.get().ptr == other.get().ptr;
3088            })
3089       .def("__eq__", [](PyValue &self, py::object other) { return false; })
3090       .def("__hash__",
3091            [](PyValue &self) {
3092              return static_cast<size_t>(llvm::hash_value(self.get().ptr));
3093            })
3094       .def(
3095           "__str__",
3096           [](PyValue &self) {
3097             PyPrintAccumulator printAccum;
3098             printAccum.parts.append("Value(");
3099             mlirValuePrint(self.get(), printAccum.getCallback(),
3100                            printAccum.getUserData());
3101             printAccum.parts.append(")");
3102             return printAccum.join();
3103           },
3104           kValueDunderStrDocstring)
3105       .def_property_readonly("type", [](PyValue &self) {
3106         return PyType(self.getParentOperation()->getContext(),
3107                       mlirValueGetType(self.get()));
3108       });
3109   PyBlockArgument::bind(m);
3110   PyOpResult::bind(m);
3111 
3112   //----------------------------------------------------------------------------
3113   // Mapping of SymbolTable.
3114   //----------------------------------------------------------------------------
3115   py::class_<PySymbolTable>(m, "SymbolTable", py::module_local())
3116       .def(py::init<PyOperationBase &>())
3117       .def("__getitem__", &PySymbolTable::dunderGetItem)
3118       .def("insert", &PySymbolTable::insert, py::arg("operation"))
3119       .def("erase", &PySymbolTable::erase, py::arg("operation"))
3120       .def("__delitem__", &PySymbolTable::dunderDel)
3121       .def("__contains__",
3122            [](PySymbolTable &table, const std::string &name) {
3123              return !mlirOperationIsNull(mlirSymbolTableLookup(
3124                  table, mlirStringRefCreate(name.data(), name.length())));
3125            })
3126       // Static helpers.
3127       .def_static("set_symbol_name", &PySymbolTable::setSymbolName,
3128                   py::arg("symbol"), py::arg("name"))
3129       .def_static("get_symbol_name", &PySymbolTable::getSymbolName,
3130                   py::arg("symbol"))
3131       .def_static("get_visibility", &PySymbolTable::getVisibility,
3132                   py::arg("symbol"))
3133       .def_static("set_visibility", &PySymbolTable::setVisibility,
3134                   py::arg("symbol"), py::arg("visibility"))
3135       .def_static("replace_all_symbol_uses",
3136                   &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"),
3137                   py::arg("new_symbol"), py::arg("from_op"))
3138       .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables,
3139                   py::arg("from_op"), py::arg("all_sym_uses_visible"),
3140                   py::arg("callback"));
3141 
  // Container bindings.
  // Each call below delegates to the helper class's static bind(), which
  // presumably registers that class on module `m` (list/iterator wrappers
  // over block arguments, blocks, operations, attributes, operands, results
  // and regions). NOTE(review): the order is kept as-is; pybind11 renders
  // signatures at definition time, so reordering registrations may change
  // generated docstrings.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);

  // Debug bindings.
  // Registers the global debug-flag wrapper via its static bind().
  PyGlobalDebugFlag::bind(m);
3156 }
3157