1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Registration.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include <pybind11/stl.h>
20 
21 namespace py = pybind11;
22 using namespace mlir;
23 using namespace mlir::python;
24 
25 using llvm::SmallVector;
26 using llvm::StringRef;
27 using llvm::Twine;
28 
29 //------------------------------------------------------------------------------
30 // Docstrings (trivial, non-duplicated docstrings are included inline).
31 //------------------------------------------------------------------------------
32 
// Docstring for parsing a type from its textual assembly form.
static constexpr char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Docstring for building a file/line/column Location.
static constexpr char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Docstring for parsing a whole module from its textual assembly form.
static constexpr char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring for the generic operation-creation entry point.
static constexpr char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
69 
// Docstring for printing an operation's assembly form to a file-like object.
// Keyword names are spelled in lowercase snake_case (e.g. `use_local_scope`),
// matching the other entries below.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: If set, elides elements attributes with more than
    this number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";
88 
// Docstring for returning an operation's assembly form as bytes or str.
static constexpr char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Docstring for an operation's default-options string conversion.
static constexpr char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared docstring for dump() methods.
static constexpr char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring for appending a block to a block list.
static constexpr char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Docstring for a value's string conversion.
static constexpr char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
127 
128 //------------------------------------------------------------------------------
129 // Utilities.
130 //------------------------------------------------------------------------------
131 
132 // Helper for creating an @classmethod.
133 template <class Func, typename... Args>
134 py::object classmethod(Func f, Args... args) {
135   py::object cf = py::cpp_function(f, args...);
136   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
137 }
138 
139 static py::object
140 createCustomDialectWrapper(const std::string &dialectNamespace,
141                            py::object dialectDescriptor) {
142   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
143   if (!dialectClass) {
144     // Use the base class.
145     return py::cast(PyDialect(std::move(dialectDescriptor)));
146   }
147 
148   // Create the custom implementation.
149   return (*dialectClass)(std::move(dialectDescriptor));
150 }
151 
152 static MlirStringRef toMlirStringRef(const std::string &s) {
153   return mlirStringRefCreate(s.data(), s.size());
154 }
155 
156 //------------------------------------------------------------------------------
157 // Collections.
158 //------------------------------------------------------------------------------
159 
160 namespace {
161 
162 class PyRegionIterator {
163 public:
164   PyRegionIterator(PyOperationRef operation)
165       : operation(std::move(operation)) {}
166 
167   PyRegionIterator &dunderIter() { return *this; }
168 
169   PyRegion dunderNext() {
170     operation->checkValid();
171     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
172       throw py::stop_iteration();
173     }
174     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
175     return PyRegion(operation, region);
176   }
177 
178   static void bind(py::module &m) {
179     py::class_<PyRegionIterator>(m, "RegionIterator")
180         .def("__iter__", &PyRegionIterator::dunderIter)
181         .def("__next__", &PyRegionIterator::dunderNext);
182   }
183 
184 private:
185   PyOperationRef operation;
186   int nextIndex = 0;
187 };
188 
189 /// Regions of an op are fixed length and indexed numerically so are represented
190 /// with a sequence-like container.
191 class PyRegionList {
192 public:
193   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
194 
195   intptr_t dunderLen() {
196     operation->checkValid();
197     return mlirOperationGetNumRegions(operation->get());
198   }
199 
200   PyRegion dunderGetItem(intptr_t index) {
201     // dunderLen checks validity.
202     if (index < 0 || index >= dunderLen()) {
203       throw SetPyError(PyExc_IndexError,
204                        "attempt to access out of bounds region");
205     }
206     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
207     return PyRegion(operation, region);
208   }
209 
210   static void bind(py::module &m) {
211     py::class_<PyRegionList>(m, "RegionSequence")
212         .def("__len__", &PyRegionList::dunderLen)
213         .def("__getitem__", &PyRegionList::dunderGetItem);
214   }
215 
216 private:
217   PyOperationRef operation;
218 };
219 
220 class PyBlockIterator {
221 public:
222   PyBlockIterator(PyOperationRef operation, MlirBlock next)
223       : operation(std::move(operation)), next(next) {}
224 
225   PyBlockIterator &dunderIter() { return *this; }
226 
227   PyBlock dunderNext() {
228     operation->checkValid();
229     if (mlirBlockIsNull(next)) {
230       throw py::stop_iteration();
231     }
232 
233     PyBlock returnBlock(operation, next);
234     next = mlirBlockGetNextInRegion(next);
235     return returnBlock;
236   }
237 
238   static void bind(py::module &m) {
239     py::class_<PyBlockIterator>(m, "BlockIterator")
240         .def("__iter__", &PyBlockIterator::dunderIter)
241         .def("__next__", &PyBlockIterator::dunderNext);
242   }
243 
244 private:
245   PyOperationRef operation;
246   MlirBlock next;
247 };
248 
249 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
250 /// we present them as a more full-featured list-like container but optimize
251 /// it for forward iteration. Blocks are always owned by a region.
252 class PyBlockList {
253 public:
254   PyBlockList(PyOperationRef operation, MlirRegion region)
255       : operation(std::move(operation)), region(region) {}
256 
257   PyBlockIterator dunderIter() {
258     operation->checkValid();
259     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
260   }
261 
262   intptr_t dunderLen() {
263     operation->checkValid();
264     intptr_t count = 0;
265     MlirBlock block = mlirRegionGetFirstBlock(region);
266     while (!mlirBlockIsNull(block)) {
267       count += 1;
268       block = mlirBlockGetNextInRegion(block);
269     }
270     return count;
271   }
272 
273   PyBlock dunderGetItem(intptr_t index) {
274     operation->checkValid();
275     if (index < 0) {
276       throw SetPyError(PyExc_IndexError,
277                        "attempt to access out of bounds block");
278     }
279     MlirBlock block = mlirRegionGetFirstBlock(region);
280     while (!mlirBlockIsNull(block)) {
281       if (index == 0) {
282         return PyBlock(operation, block);
283       }
284       block = mlirBlockGetNextInRegion(block);
285       index -= 1;
286     }
287     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
288   }
289 
290   PyBlock appendBlock(py::args pyArgTypes) {
291     operation->checkValid();
292     llvm::SmallVector<MlirType, 4> argTypes;
293     argTypes.reserve(pyArgTypes.size());
294     for (auto &pyArg : pyArgTypes) {
295       argTypes.push_back(pyArg.cast<PyType &>());
296     }
297 
298     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
299     mlirRegionAppendOwnedBlock(region, block);
300     return PyBlock(operation, block);
301   }
302 
303   static void bind(py::module &m) {
304     py::class_<PyBlockList>(m, "BlockList")
305         .def("__getitem__", &PyBlockList::dunderGetItem)
306         .def("__iter__", &PyBlockList::dunderIter)
307         .def("__len__", &PyBlockList::dunderLen)
308         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
309   }
310 
311 private:
312   PyOperationRef operation;
313   MlirRegion region;
314 };
315 
316 class PyOperationIterator {
317 public:
318   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
319       : parentOperation(std::move(parentOperation)), next(next) {}
320 
321   PyOperationIterator &dunderIter() { return *this; }
322 
323   py::object dunderNext() {
324     parentOperation->checkValid();
325     if (mlirOperationIsNull(next)) {
326       throw py::stop_iteration();
327     }
328 
329     PyOperationRef returnOperation =
330         PyOperation::forOperation(parentOperation->getContext(), next);
331     next = mlirOperationGetNextInBlock(next);
332     return returnOperation->createOpView();
333   }
334 
335   static void bind(py::module &m) {
336     py::class_<PyOperationIterator>(m, "OperationIterator")
337         .def("__iter__", &PyOperationIterator::dunderIter)
338         .def("__next__", &PyOperationIterator::dunderNext);
339   }
340 
341 private:
342   PyOperationRef parentOperation;
343   MlirOperation next;
344 };
345 
346 /// Operations are exposed by the C-API as a forward-only linked list. In
347 /// Python, we present them as a more full-featured list-like container but
348 /// optimize it for forward iteration. Iterable operations are always owned
349 /// by a block.
350 class PyOperationList {
351 public:
352   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
353       : parentOperation(std::move(parentOperation)), block(block) {}
354 
355   PyOperationIterator dunderIter() {
356     parentOperation->checkValid();
357     return PyOperationIterator(parentOperation,
358                                mlirBlockGetFirstOperation(block));
359   }
360 
361   intptr_t dunderLen() {
362     parentOperation->checkValid();
363     intptr_t count = 0;
364     MlirOperation childOp = mlirBlockGetFirstOperation(block);
365     while (!mlirOperationIsNull(childOp)) {
366       count += 1;
367       childOp = mlirOperationGetNextInBlock(childOp);
368     }
369     return count;
370   }
371 
372   py::object dunderGetItem(intptr_t index) {
373     parentOperation->checkValid();
374     if (index < 0) {
375       throw SetPyError(PyExc_IndexError,
376                        "attempt to access out of bounds operation");
377     }
378     MlirOperation childOp = mlirBlockGetFirstOperation(block);
379     while (!mlirOperationIsNull(childOp)) {
380       if (index == 0) {
381         return PyOperation::forOperation(parentOperation->getContext(), childOp)
382             ->createOpView();
383       }
384       childOp = mlirOperationGetNextInBlock(childOp);
385       index -= 1;
386     }
387     throw SetPyError(PyExc_IndexError,
388                      "attempt to access out of bounds operation");
389   }
390 
391   static void bind(py::module &m) {
392     py::class_<PyOperationList>(m, "OperationList")
393         .def("__getitem__", &PyOperationList::dunderGetItem)
394         .def("__iter__", &PyOperationList::dunderIter)
395         .def("__len__", &PyOperationList::dunderLen);
396   }
397 
398 private:
399   PyOperationRef parentOperation;
400   MlirBlock block;
401 };
402 
403 } // namespace
404 
405 //------------------------------------------------------------------------------
406 // PyMlirContext
407 //------------------------------------------------------------------------------
408 
409 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
410   py::gil_scoped_acquire acquire;
411   auto &liveContexts = getLiveContexts();
412   liveContexts[context.ptr] = this;
413 }
414 
415 PyMlirContext::~PyMlirContext() {
416   // Note that the only public way to construct an instance is via the
417   // forContext method, which always puts the associated handle into
418   // liveContexts.
419   py::gil_scoped_acquire acquire;
420   getLiveContexts().erase(context.ptr);
421   mlirContextDestroy(context);
422 }
423 
424 py::object PyMlirContext::getCapsule() {
425   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
426 }
427 
428 py::object PyMlirContext::createFromCapsule(py::object capsule) {
429   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
430   if (mlirContextIsNull(rawContext))
431     throw py::error_already_set();
432   return forContext(rawContext).releaseObject();
433 }
434 
435 PyMlirContext *PyMlirContext::createNewContextForInit() {
436   MlirContext context = mlirContextCreate();
437   mlirRegisterAllDialects(context);
438   return new PyMlirContext(context);
439 }
440 
441 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
442   py::gil_scoped_acquire acquire;
443   auto &liveContexts = getLiveContexts();
444   auto it = liveContexts.find(context.ptr);
445   if (it == liveContexts.end()) {
446     // Create.
447     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
448     py::object pyRef = py::cast(unownedContextWrapper);
449     assert(pyRef && "cast to py::object failed");
450     liveContexts[context.ptr] = unownedContextWrapper;
451     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
452   }
453   // Use existing.
454   py::object pyRef = py::cast(it->second);
455   return PyMlirContextRef(it->second, std::move(pyRef));
456 }
457 
458 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
459   static LiveContextMap liveContexts;
460   return liveContexts;
461 }
462 
463 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
464 
465 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
466 
467 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
468 
469 pybind11::object PyMlirContext::contextEnter() {
470   return PyThreadContextEntry::pushContext(*this);
471 }
472 
473 void PyMlirContext::contextExit(pybind11::object excType,
474                                 pybind11::object excVal,
475                                 pybind11::object excTb) {
476   PyThreadContextEntry::popContext(*this);
477 }
478 
479 PyMlirContext &DefaultingPyMlirContext::resolve() {
480   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
481   if (!context) {
482     throw SetPyError(
483         PyExc_RuntimeError,
484         "An MLIR function requires a Context but none was provided in the call "
485         "or from the surrounding environment. Either pass to the function with "
486         "a 'context=' argument or establish a default using 'with Context():'");
487   }
488   return *context;
489 }
490 
491 //------------------------------------------------------------------------------
492 // PyThreadContextEntry management
493 //------------------------------------------------------------------------------
494 
495 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
496   static thread_local std::vector<PyThreadContextEntry> stack;
497   return stack;
498 }
499 
500 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
501   auto &stack = getStack();
502   if (stack.empty())
503     return nullptr;
504   return &stack.back();
505 }
506 
507 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
508                                 py::object insertionPoint,
509                                 py::object location) {
510   auto &stack = getStack();
511   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
512                      std::move(location));
513   // If the new stack has more than one entry and the context of the new top
514   // entry matches the previous, copy the insertionPoint and location from the
515   // previous entry if missing from the new top entry.
516   if (stack.size() > 1) {
517     auto &prev = *(stack.rbegin() + 1);
518     auto &current = stack.back();
519     if (current.context.is(prev.context)) {
520       // Default non-context objects from the previous entry.
521       if (!current.insertionPoint)
522         current.insertionPoint = prev.insertionPoint;
523       if (!current.location)
524         current.location = prev.location;
525     }
526   }
527 }
528 
529 PyMlirContext *PyThreadContextEntry::getContext() {
530   if (!context)
531     return nullptr;
532   return py::cast<PyMlirContext *>(context);
533 }
534 
535 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
536   if (!insertionPoint)
537     return nullptr;
538   return py::cast<PyInsertionPoint *>(insertionPoint);
539 }
540 
541 PyLocation *PyThreadContextEntry::getLocation() {
542   if (!location)
543     return nullptr;
544   return py::cast<PyLocation *>(location);
545 }
546 
547 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
548   auto *tos = getTopOfStack();
549   return tos ? tos->getContext() : nullptr;
550 }
551 
552 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
553   auto *tos = getTopOfStack();
554   return tos ? tos->getInsertionPoint() : nullptr;
555 }
556 
557 PyLocation *PyThreadContextEntry::getDefaultLocation() {
558   auto *tos = getTopOfStack();
559   return tos ? tos->getLocation() : nullptr;
560 }
561 
562 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
563   py::object contextObj = py::cast(context);
564   push(FrameKind::Context, /*context=*/contextObj,
565        /*insertionPoint=*/py::object(),
566        /*location=*/py::object());
567   return contextObj;
568 }
569 
570 void PyThreadContextEntry::popContext(PyMlirContext &context) {
571   auto &stack = getStack();
572   if (stack.empty())
573     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
574   auto &tos = stack.back();
575   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
576     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
577   stack.pop_back();
578 }
579 
580 py::object
581 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
582   py::object contextObj =
583       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
584   py::object insertionPointObj = py::cast(insertionPoint);
585   push(FrameKind::InsertionPoint,
586        /*context=*/contextObj,
587        /*insertionPoint=*/insertionPointObj,
588        /*location=*/py::object());
589   return insertionPointObj;
590 }
591 
592 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
593   auto &stack = getStack();
594   if (stack.empty())
595     throw SetPyError(PyExc_RuntimeError,
596                      "Unbalanced InsertionPoint enter/exit");
597   auto &tos = stack.back();
598   if (tos.frameKind != FrameKind::InsertionPoint &&
599       tos.getInsertionPoint() != &insertionPoint)
600     throw SetPyError(PyExc_RuntimeError,
601                      "Unbalanced InsertionPoint enter/exit");
602   stack.pop_back();
603 }
604 
605 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
606   py::object contextObj = location.getContext().getObject();
607   py::object locationObj = py::cast(location);
608   push(FrameKind::Location, /*context=*/contextObj,
609        /*insertionPoint=*/py::object(),
610        /*location=*/locationObj);
611   return locationObj;
612 }
613 
614 void PyThreadContextEntry::popLocation(PyLocation &location) {
615   auto &stack = getStack();
616   if (stack.empty())
617     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
618   auto &tos = stack.back();
619   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
620     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
621   stack.pop_back();
622 }
623 
624 //------------------------------------------------------------------------------
625 // PyDialect, PyDialectDescriptor, PyDialects
626 //------------------------------------------------------------------------------
627 
628 MlirDialect PyDialects::getDialectForKey(const std::string &key,
629                                          bool attrError) {
630   // If the "std" dialect was asked for, substitute the empty namespace :(
631   static const std::string emptyKey;
632   const std::string *canonKey = key == "std" ? &emptyKey : &key;
633   MlirDialect dialect = mlirContextGetOrLoadDialect(
634       getContext()->get(), {canonKey->data(), canonKey->size()});
635   if (mlirDialectIsNull(dialect)) {
636     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
637                      Twine("Dialect '") + key + "' not found");
638   }
639   return dialect;
640 }
641 
642 //------------------------------------------------------------------------------
643 // PyLocation
644 //------------------------------------------------------------------------------
645 
646 py::object PyLocation::getCapsule() {
647   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
648 }
649 
650 PyLocation PyLocation::createFromCapsule(py::object capsule) {
651   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
652   if (mlirLocationIsNull(rawLoc))
653     throw py::error_already_set();
654   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
655                     rawLoc);
656 }
657 
658 py::object PyLocation::contextEnter() {
659   return PyThreadContextEntry::pushLocation(*this);
660 }
661 
662 void PyLocation::contextExit(py::object excType, py::object excVal,
663                              py::object excTb) {
664   PyThreadContextEntry::popLocation(*this);
665 }
666 
667 PyLocation &DefaultingPyLocation::resolve() {
668   auto *location = PyThreadContextEntry::getDefaultLocation();
669   if (!location) {
670     throw SetPyError(
671         PyExc_RuntimeError,
672         "An MLIR function requires a Location but none was provided in the "
673         "call or from the surrounding environment. Either pass to the function "
674         "with a 'loc=' argument or establish a default using 'with loc:'");
675   }
676   return *location;
677 }
678 
679 //------------------------------------------------------------------------------
680 // PyModule
681 //------------------------------------------------------------------------------
682 
683 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
684     : BaseContextObject(std::move(contextRef)), module(module) {}
685 
686 PyModule::~PyModule() {
687   py::gil_scoped_acquire acquire;
688   auto &liveModules = getContext()->liveModules;
689   assert(liveModules.count(module.ptr) == 1 &&
690          "destroying module not in live map");
691   liveModules.erase(module.ptr);
692   mlirModuleDestroy(module);
693 }
694 
695 PyModuleRef PyModule::forModule(MlirModule module) {
696   MlirContext context = mlirModuleGetContext(module);
697   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
698 
699   py::gil_scoped_acquire acquire;
700   auto &liveModules = contextRef->liveModules;
701   auto it = liveModules.find(module.ptr);
702   if (it == liveModules.end()) {
703     // Create.
704     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
705     // Note that the default return value policy on cast is automatic_reference,
706     // which does not take ownership (delete will not be called).
707     // Just be explicit.
708     py::object pyRef =
709         py::cast(unownedModule, py::return_value_policy::take_ownership);
710     unownedModule->handle = pyRef;
711     liveModules[module.ptr] =
712         std::make_pair(unownedModule->handle, unownedModule);
713     return PyModuleRef(unownedModule, std::move(pyRef));
714   }
715   // Use existing.
716   PyModule *existing = it->second.second;
717   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
718   return PyModuleRef(existing, std::move(pyRef));
719 }
720 
721 py::object PyModule::createFromCapsule(py::object capsule) {
722   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
723   if (mlirModuleIsNull(rawModule))
724     throw py::error_already_set();
725   return forModule(rawModule).releaseObject();
726 }
727 
728 py::object PyModule::getCapsule() {
729   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
730 }
731 
732 //------------------------------------------------------------------------------
733 // PyOperation
734 //------------------------------------------------------------------------------
735 
736 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
737     : BaseContextObject(std::move(contextRef)), operation(operation) {}
738 
739 PyOperation::~PyOperation() {
740   auto &liveOperations = getContext()->liveOperations;
741   assert(liveOperations.count(operation.ptr) == 1 &&
742          "destroying operation not in live map");
743   liveOperations.erase(operation.ptr);
744   if (!isAttached()) {
745     mlirOperationDestroy(operation);
746   }
747 }
748 
749 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
750                                            MlirOperation operation,
751                                            py::object parentKeepAlive) {
752   auto &liveOperations = contextRef->liveOperations;
753   // Create.
754   PyOperation *unownedOperation =
755       new PyOperation(std::move(contextRef), operation);
756   // Note that the default return value policy on cast is automatic_reference,
757   // which does not take ownership (delete will not be called).
758   // Just be explicit.
759   py::object pyRef =
760       py::cast(unownedOperation, py::return_value_policy::take_ownership);
761   unownedOperation->handle = pyRef;
762   if (parentKeepAlive) {
763     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
764   }
765   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
766   return PyOperationRef(unownedOperation, std::move(pyRef));
767 }
768 
769 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
770                                          MlirOperation operation,
771                                          py::object parentKeepAlive) {
772   auto &liveOperations = contextRef->liveOperations;
773   auto it = liveOperations.find(operation.ptr);
774   if (it == liveOperations.end()) {
775     // Create.
776     return createInstance(std::move(contextRef), operation,
777                           std::move(parentKeepAlive));
778   }
779   // Use existing.
780   PyOperation *existing = it->second.second;
781   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
782   return PyOperationRef(existing, std::move(pyRef));
783 }
784 
785 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
786                                            MlirOperation operation,
787                                            py::object parentKeepAlive) {
788   auto &liveOperations = contextRef->liveOperations;
789   assert(liveOperations.count(operation.ptr) == 0 &&
790          "cannot create detached operation that already exists");
791   (void)liveOperations;
792 
793   PyOperationRef created = createInstance(std::move(contextRef), operation,
794                                           std::move(parentKeepAlive));
795   created->attached = false;
796   return created;
797 }
798 
799 void PyOperation::checkValid() const {
800   if (!valid) {
801     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
802   }
803 }
804 
805 void PyOperationBase::print(py::object fileObject, bool binary,
806                             llvm::Optional<int64_t> largeElementsLimit,
807                             bool enableDebugInfo, bool prettyDebugInfo,
808                             bool printGenericOpForm, bool useLocalScope) {
809   PyOperation &operation = getOperation();
810   operation.checkValid();
811   if (fileObject.is_none())
812     fileObject = py::module::import("sys").attr("stdout");
813 
814   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
815     fileObject.attr("write")("// Verification failed, printing generic form\n");
816     printGenericOpForm = true;
817   }
818 
819   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
820   if (largeElementsLimit)
821     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
822   if (enableDebugInfo)
823     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
824   if (printGenericOpForm)
825     mlirOpPrintingFlagsPrintGenericOpForm(flags);
826 
827   PyFileAccumulator accum(fileObject, binary);
828   py::gil_scoped_release();
829   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
830                               accum.getUserData());
831   mlirOpPrintingFlagsDestroy(flags);
832 }
833 
834 py::object PyOperationBase::getAsm(bool binary,
835                                    llvm::Optional<int64_t> largeElementsLimit,
836                                    bool enableDebugInfo, bool prettyDebugInfo,
837                                    bool printGenericOpForm,
838                                    bool useLocalScope) {
839   py::object fileObject;
840   if (binary) {
841     fileObject = py::module::import("io").attr("BytesIO")();
842   } else {
843     fileObject = py::module::import("io").attr("StringIO")();
844   }
845   print(fileObject, /*binary=*/binary,
846         /*largeElementsLimit=*/largeElementsLimit,
847         /*enableDebugInfo=*/enableDebugInfo,
848         /*prettyDebugInfo=*/prettyDebugInfo,
849         /*printGenericOpForm=*/printGenericOpForm,
850         /*useLocalScope=*/useLocalScope);
851 
852   return fileObject.attr("getvalue")();
853 }
854 
855 PyOperationRef PyOperation::getParentOperation() {
856   if (!isAttached())
857     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
858   MlirOperation operation = mlirOperationGetParentOperation(get());
859   if (mlirOperationIsNull(operation))
860     throw SetPyError(PyExc_ValueError, "Operation has no parent.");
861   return PyOperation::forOperation(getContext(), operation);
862 }
863 
864 PyBlock PyOperation::getBlock() {
865   PyOperationRef parentOperation = getParentOperation();
866   MlirBlock block = mlirOperationGetBlock(get());
867   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
868   return PyBlock{std::move(parentOperation), block};
869 }
870 
871 py::object PyOperation::getCapsule() {
872   return py::reinterpret_steal<py::object>(mlirPythonOperationToCapsule(get()));
873 }
874 
875 py::object PyOperation::createFromCapsule(py::object capsule) {
876   MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr());
877   if (mlirOperationIsNull(rawOperation))
878     throw py::error_already_set();
879   MlirContext rawCtxt = mlirOperationGetContext(rawOperation);
880   return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation)
881       .releaseObject();
882 }
883 
884 py::object PyOperation::create(
885     std::string name, llvm::Optional<std::vector<PyType *>> results,
886     llvm::Optional<std::vector<PyValue *>> operands,
887     llvm::Optional<py::dict> attributes,
888     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
889     DefaultingPyLocation location, py::object maybeIp) {
890   llvm::SmallVector<MlirValue, 4> mlirOperands;
891   llvm::SmallVector<MlirType, 4> mlirResults;
892   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
893   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
894 
895   // General parameter validation.
896   if (regions < 0)
897     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
898 
899   // Unpack/validate operands.
900   if (operands) {
901     mlirOperands.reserve(operands->size());
902     for (PyValue *operand : *operands) {
903       if (!operand)
904         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
905       mlirOperands.push_back(operand->get());
906     }
907   }
908 
909   // Unpack/validate results.
910   if (results) {
911     mlirResults.reserve(results->size());
912     for (PyType *result : *results) {
913       // TODO: Verify result type originate from the same context.
914       if (!result)
915         throw SetPyError(PyExc_ValueError, "result type cannot be None");
916       mlirResults.push_back(*result);
917     }
918   }
919   // Unpack/validate attributes.
920   if (attributes) {
921     mlirAttributes.reserve(attributes->size());
922     for (auto &it : *attributes) {
923       std::string key;
924       try {
925         key = it.first.cast<std::string>();
926       } catch (py::cast_error &err) {
927         std::string msg = "Invalid attribute key (not a string) when "
928                           "attempting to create the operation \"" +
929                           name + "\" (" + err.what() + ")";
930         throw py::cast_error(msg);
931       }
932       try {
933         auto &attribute = it.second.cast<PyAttribute &>();
934         // TODO: Verify attribute originates from the same context.
935         mlirAttributes.emplace_back(std::move(key), attribute);
936       } catch (py::reference_cast_error &) {
937         // This exception seems thrown when the value is "None".
938         std::string msg =
939             "Found an invalid (`None`?) attribute value for the key \"" + key +
940             "\" when attempting to create the operation \"" + name + "\"";
941         throw py::cast_error(msg);
942       } catch (py::cast_error &err) {
943         std::string msg = "Invalid attribute value for the key \"" + key +
944                           "\" when attempting to create the operation \"" +
945                           name + "\" (" + err.what() + ")";
946         throw py::cast_error(msg);
947       }
948     }
949   }
950   // Unpack/validate successors.
951   if (successors) {
952     llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
953     mlirSuccessors.reserve(successors->size());
954     for (auto *successor : *successors) {
955       // TODO: Verify successor originate from the same context.
956       if (!successor)
957         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
958       mlirSuccessors.push_back(successor->get());
959     }
960   }
961 
962   // Apply unpacked/validated to the operation state. Beyond this
963   // point, exceptions cannot be thrown or else the state will leak.
964   MlirOperationState state =
965       mlirOperationStateGet(toMlirStringRef(name), location);
966   if (!mlirOperands.empty())
967     mlirOperationStateAddOperands(&state, mlirOperands.size(),
968                                   mlirOperands.data());
969   if (!mlirResults.empty())
970     mlirOperationStateAddResults(&state, mlirResults.size(),
971                                  mlirResults.data());
972   if (!mlirAttributes.empty()) {
973     // Note that the attribute names directly reference bytes in
974     // mlirAttributes, so that vector must not be changed from here
975     // on.
976     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
977     mlirNamedAttributes.reserve(mlirAttributes.size());
978     for (auto &it : mlirAttributes)
979       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
980           mlirIdentifierGet(mlirAttributeGetContext(it.second),
981                             toMlirStringRef(it.first)),
982           it.second));
983     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
984                                     mlirNamedAttributes.data());
985   }
986   if (!mlirSuccessors.empty())
987     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
988                                     mlirSuccessors.data());
989   if (regions) {
990     llvm::SmallVector<MlirRegion, 4> mlirRegions;
991     mlirRegions.resize(regions);
992     for (int i = 0; i < regions; ++i)
993       mlirRegions[i] = mlirRegionCreate();
994     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
995                                       mlirRegions.data());
996   }
997 
998   // Construct the operation.
999   MlirOperation operation = mlirOperationCreate(&state);
1000   PyOperationRef created =
1001       PyOperation::createDetached(location->getContext(), operation);
1002 
1003   // InsertPoint active?
1004   if (!maybeIp.is(py::cast(false))) {
1005     PyInsertionPoint *ip;
1006     if (maybeIp.is_none()) {
1007       ip = PyThreadContextEntry::getDefaultInsertionPoint();
1008     } else {
1009       ip = py::cast<PyInsertionPoint *>(maybeIp);
1010     }
1011     if (ip)
1012       ip->insert(*created.get());
1013   }
1014 
1015   return created->createOpView();
1016 }
1017 
1018 py::object PyOperation::createOpView() {
1019   MlirIdentifier ident = mlirOperationGetName(get());
1020   MlirStringRef identStr = mlirIdentifierStr(ident);
1021   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1022       StringRef(identStr.data, identStr.length));
1023   if (opViewClass)
1024     return (*opViewClass)(getRef().getObject());
1025   return py::cast(PyOpView(getRef().getObject()));
1026 }
1027 
1028 //------------------------------------------------------------------------------
1029 // PyOpView
1030 //------------------------------------------------------------------------------
1031 
1032 py::object
1033 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1034                        py::list operandList,
1035                        llvm::Optional<py::dict> attributes,
1036                        llvm::Optional<std::vector<PyBlock *>> successors,
1037                        llvm::Optional<int> regions,
1038                        DefaultingPyLocation location, py::object maybeIp) {
1039   PyMlirContextRef context = location->getContext();
1040   // Class level operation construction metadata.
1041   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1042   // Operand and result segment specs are either none, which does no
1043   // variadic unpacking, or a list of ints with segment sizes, where each
1044   // element is either a positive number (typically 1 for a scalar) or -1 to
1045   // indicate that it is derived from the length of the same-indexed operand
1046   // or result (implying that it is a list at that position).
1047   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1048   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1049 
1050   std::vector<uint32_t> operandSegmentLengths;
1051   std::vector<uint32_t> resultSegmentLengths;
1052 
1053   // Validate/determine region count.
1054   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1055   int opMinRegionCount = std::get<0>(opRegionSpec);
1056   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1057   if (!regions) {
1058     regions = opMinRegionCount;
1059   }
1060   if (*regions < opMinRegionCount) {
1061     throw py::value_error(
1062         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1063          llvm::Twine(opMinRegionCount) +
1064          " regions but was built with regions=" + llvm::Twine(*regions))
1065             .str());
1066   }
1067   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1068     throw py::value_error(
1069         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1070          llvm::Twine(opMinRegionCount) +
1071          " regions but was built with regions=" + llvm::Twine(*regions))
1072             .str());
1073   }
1074 
1075   // Unpack results.
1076   std::vector<PyType *> resultTypes;
1077   resultTypes.reserve(resultTypeList.size());
1078   if (resultSegmentSpecObj.is_none()) {
1079     // Non-variadic result unpacking.
1080     for (auto it : llvm::enumerate(resultTypeList)) {
1081       try {
1082         resultTypes.push_back(py::cast<PyType *>(it.value()));
1083         if (!resultTypes.back())
1084           throw py::cast_error();
1085       } catch (py::cast_error &err) {
1086         throw py::value_error((llvm::Twine("Result ") +
1087                                llvm::Twine(it.index()) + " of operation \"" +
1088                                name + "\" must be a Type (" + err.what() + ")")
1089                                   .str());
1090       }
1091     }
1092   } else {
1093     // Sized result unpacking.
1094     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1095     if (resultSegmentSpec.size() != resultTypeList.size()) {
1096       throw py::value_error((llvm::Twine("Operation \"") + name +
1097                              "\" requires " +
1098                              llvm::Twine(resultSegmentSpec.size()) +
1099                              "result segments but was provided " +
1100                              llvm::Twine(resultTypeList.size()))
1101                                 .str());
1102     }
1103     resultSegmentLengths.reserve(resultTypeList.size());
1104     for (auto it :
1105          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1106       int segmentSpec = std::get<1>(it.value());
1107       if (segmentSpec == 1 || segmentSpec == 0) {
1108         // Unpack unary element.
1109         try {
1110           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1111           if (resultType) {
1112             resultTypes.push_back(resultType);
1113             resultSegmentLengths.push_back(1);
1114           } else if (segmentSpec == 0) {
1115             // Allowed to be optional.
1116             resultSegmentLengths.push_back(0);
1117           } else {
1118             throw py::cast_error("was None and result is not optional");
1119           }
1120         } catch (py::cast_error &err) {
1121           throw py::value_error((llvm::Twine("Result ") +
1122                                  llvm::Twine(it.index()) + " of operation \"" +
1123                                  name + "\" must be a Type (" + err.what() +
1124                                  ")")
1125                                     .str());
1126         }
1127       } else if (segmentSpec == -1) {
1128         // Unpack sequence by appending.
1129         try {
1130           if (std::get<0>(it.value()).is_none()) {
1131             // Treat it as an empty list.
1132             resultSegmentLengths.push_back(0);
1133           } else {
1134             // Unpack the list.
1135             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1136             for (py::object segmentItem : segment) {
1137               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1138               if (!resultTypes.back()) {
1139                 throw py::cast_error("contained a None item");
1140               }
1141             }
1142             resultSegmentLengths.push_back(segment.size());
1143           }
1144         } catch (std::exception &err) {
1145           // NOTE: Sloppy to be using a catch-all here, but there are at least
1146           // three different unrelated exceptions that can be thrown in the
1147           // above "casts". Just keep the scope above small and catch them all.
1148           throw py::value_error((llvm::Twine("Result ") +
1149                                  llvm::Twine(it.index()) + " of operation \"" +
1150                                  name + "\" must be a Sequence of Types (" +
1151                                  err.what() + ")")
1152                                     .str());
1153         }
1154       } else {
1155         throw py::value_error("Unexpected segment spec");
1156       }
1157     }
1158   }
1159 
1160   // Unpack operands.
1161   std::vector<PyValue *> operands;
1162   operands.reserve(operands.size());
1163   if (operandSegmentSpecObj.is_none()) {
1164     // Non-sized operand unpacking.
1165     for (auto it : llvm::enumerate(operandList)) {
1166       try {
1167         operands.push_back(py::cast<PyValue *>(it.value()));
1168         if (!operands.back())
1169           throw py::cast_error();
1170       } catch (py::cast_error &err) {
1171         throw py::value_error((llvm::Twine("Operand ") +
1172                                llvm::Twine(it.index()) + " of operation \"" +
1173                                name + "\" must be a Value (" + err.what() + ")")
1174                                   .str());
1175       }
1176     }
1177   } else {
1178     // Sized operand unpacking.
1179     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1180     if (operandSegmentSpec.size() != operandList.size()) {
1181       throw py::value_error((llvm::Twine("Operation \"") + name +
1182                              "\" requires " +
1183                              llvm::Twine(operandSegmentSpec.size()) +
1184                              "operand segments but was provided " +
1185                              llvm::Twine(operandList.size()))
1186                                 .str());
1187     }
1188     operandSegmentLengths.reserve(operandList.size());
1189     for (auto it :
1190          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1191       int segmentSpec = std::get<1>(it.value());
1192       if (segmentSpec == 1 || segmentSpec == 0) {
1193         // Unpack unary element.
1194         try {
1195           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1196           if (operandValue) {
1197             operands.push_back(operandValue);
1198             operandSegmentLengths.push_back(1);
1199           } else if (segmentSpec == 0) {
1200             // Allowed to be optional.
1201             operandSegmentLengths.push_back(0);
1202           } else {
1203             throw py::cast_error("was None and operand is not optional");
1204           }
1205         } catch (py::cast_error &err) {
1206           throw py::value_error((llvm::Twine("Operand ") +
1207                                  llvm::Twine(it.index()) + " of operation \"" +
1208                                  name + "\" must be a Value (" + err.what() +
1209                                  ")")
1210                                     .str());
1211         }
1212       } else if (segmentSpec == -1) {
1213         // Unpack sequence by appending.
1214         try {
1215           if (std::get<0>(it.value()).is_none()) {
1216             // Treat it as an empty list.
1217             operandSegmentLengths.push_back(0);
1218           } else {
1219             // Unpack the list.
1220             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1221             for (py::object segmentItem : segment) {
1222               operands.push_back(py::cast<PyValue *>(segmentItem));
1223               if (!operands.back()) {
1224                 throw py::cast_error("contained a None item");
1225               }
1226             }
1227             operandSegmentLengths.push_back(segment.size());
1228           }
1229         } catch (std::exception &err) {
1230           // NOTE: Sloppy to be using a catch-all here, but there are at least
1231           // three different unrelated exceptions that can be thrown in the
1232           // above "casts". Just keep the scope above small and catch them all.
1233           throw py::value_error((llvm::Twine("Operand ") +
1234                                  llvm::Twine(it.index()) + " of operation \"" +
1235                                  name + "\" must be a Sequence of Values (" +
1236                                  err.what() + ")")
1237                                     .str());
1238         }
1239       } else {
1240         throw py::value_error("Unexpected segment spec");
1241       }
1242     }
1243   }
1244 
1245   // Merge operand/result segment lengths into attributes if needed.
1246   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1247     // Dup.
1248     if (attributes) {
1249       attributes = py::dict(*attributes);
1250     } else {
1251       attributes = py::dict();
1252     }
1253     if (attributes->contains("result_segment_sizes") ||
1254         attributes->contains("operand_segment_sizes")) {
1255       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1256                             "'operand_segment_sizes' attribute is unsupported. "
1257                             "Use Operation.create for such low-level access.");
1258     }
1259 
1260     // Add result_segment_sizes attribute.
1261     if (!resultSegmentLengths.empty()) {
1262       int64_t size = resultSegmentLengths.size();
1263       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1264           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1265           resultSegmentLengths.size(), resultSegmentLengths.data());
1266       (*attributes)["result_segment_sizes"] =
1267           PyAttribute(context, segmentLengthAttr);
1268     }
1269 
1270     // Add operand_segment_sizes attribute.
1271     if (!operandSegmentLengths.empty()) {
1272       int64_t size = operandSegmentLengths.size();
1273       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
1274           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
1275           operandSegmentLengths.size(), operandSegmentLengths.data());
1276       (*attributes)["operand_segment_sizes"] =
1277           PyAttribute(context, segmentLengthAttr);
1278     }
1279   }
1280 
1281   // Delegate to create.
1282   return PyOperation::create(std::move(name),
1283                              /*results=*/std::move(resultTypes),
1284                              /*operands=*/std::move(operands),
1285                              /*attributes=*/std::move(attributes),
1286                              /*successors=*/std::move(successors),
1287                              /*regions=*/*regions, location, maybeIp);
1288 }
1289 
/// Wraps an existing Operation (or any PyOperationBase subclass object,
/// including another OpView) in an OpView.
///
/// The stored `operationObject` is the canonical Python object of the
/// underlying PyOperation (obtained via `operation.getRef()`), not
/// necessarily the argument itself. Because that member is initialized from
/// the `operation` member, this relies on `operation` being declared before
/// `operationObject` in the class (member initialization order).
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      operationObject(operation.getRef().getObject()) {}
1295 
1296 py::object PyOpView::createRawSubclass(py::object userClass) {
1297   // This is... a little gross. The typical pattern is to have a pure python
1298   // class that extends OpView like:
1299   //   class AddFOp(_cext.ir.OpView):
1300   //     def __init__(self, loc, lhs, rhs):
1301   //       operation = loc.context.create_operation(
1302   //           "addf", lhs, rhs, results=[lhs.type])
1303   //       super().__init__(operation)
1304   //
1305   // I.e. The goal of the user facing type is to provide a nice constructor
1306   // that has complete freedom for the op under construction. This is at odds
1307   // with our other desire to sometimes create this object by just passing an
1308   // operation (to initialize the base class). We could do *arg and **kwargs
1309   // munging to try to make it work, but instead, we synthesize a new class
1310   // on the fly which extends this user class (AddFOp in this example) and
1311   // *give it* the base class's __init__ method, thus bypassing the
1312   // intermediate subclass's __init__ method entirely. While slightly,
1313   // underhanded, this is safe/legal because the type hierarchy has not changed
1314   // (we just added a new leaf) and we aren't mucking around with __new__.
1315   // Typically, this new class will be stored on the original as "_Raw" and will
1316   // be used for casts and other things that need a variant of the class that
1317   // is initialized purely from an operation.
1318   py::object parentMetaclass =
1319       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1320   py::dict attributes;
1321   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1322   // now.
1323   //   auto opViewType = py::type::of<PyOpView>();
1324   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1325   attributes["__init__"] = opViewType.attr("__init__");
1326   py::str origName = userClass.attr("__name__");
1327   py::str newName = py::str("_") + origName;
1328   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1329 }
1330 
1331 //------------------------------------------------------------------------------
1332 // PyInsertionPoint.
1333 //------------------------------------------------------------------------------
1334 
/// Insertion point with no reference operation: insert() appends to the end
/// of `block`.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}

/// Insertion point that inserts immediately before `beforeOperationBase`, in
/// that operation's containing block. Raises ValueError (via getBlock) if the
/// operation is detached or has no parent.
/// NOTE: `block` is initialized by dereferencing `refOperation`, so this
/// relies on `refOperation` being declared before `block` in the class —
/// TODO confirm against the class declaration.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      block((*refOperation)->getBlock()) {}
1340 
1341 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1342   PyOperation &operation = operationBase.getOperation();
1343   if (operation.isAttached())
1344     throw SetPyError(PyExc_ValueError,
1345                      "Attempt to insert operation that is already attached");
1346   block.getParentOperation()->checkValid();
1347   MlirOperation beforeOp = {nullptr};
1348   if (refOperation) {
1349     // Insert before operation.
1350     (*refOperation)->checkValid();
1351     beforeOp = (*refOperation)->get();
1352   } else {
1353     // Insert at end (before null) is only valid if the block does not
1354     // already end in a known terminator (violating this will cause assertion
1355     // failures later).
1356     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1357       throw py::index_error("Cannot insert operation at the end of a block "
1358                             "that already has a terminator. Did you mean to "
1359                             "use 'InsertionPoint.at_block_terminator(block)' "
1360                             "versus 'InsertionPoint(block)'?");
1361     }
1362   }
1363   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1364   operation.setAttached();
1365 }
1366 
1367 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1368   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1369   if (mlirOperationIsNull(firstOp)) {
1370     // Just insert at end.
1371     return PyInsertionPoint(block);
1372   }
1373 
1374   // Insert before first op.
1375   PyOperationRef firstOpRef = PyOperation::forOperation(
1376       block.getParentOperation()->getContext(), firstOp);
1377   return PyInsertionPoint{block, std::move(firstOpRef)};
1378 }
1379 
1380 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1381   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1382   if (mlirOperationIsNull(terminator))
1383     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1384   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1385       block.getParentOperation()->getContext(), terminator);
1386   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1387 }
1388 
1389 py::object PyInsertionPoint::contextEnter() {
1390   return PyThreadContextEntry::pushInsertionPoint(*this);
1391 }
1392 
1393 void PyInsertionPoint::contextExit(pybind11::object excType,
1394                                    pybind11::object excVal,
1395                                    pybind11::object excTb) {
1396   PyThreadContextEntry::popInsertionPoint(*this);
1397 }
1398 
1399 //------------------------------------------------------------------------------
1400 // PyAttribute.
1401 //------------------------------------------------------------------------------
1402 
1403 bool PyAttribute::operator==(const PyAttribute &other) {
1404   return mlirAttributeEqual(attr, other.attr);
1405 }
1406 
1407 py::object PyAttribute::getCapsule() {
1408   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1409 }
1410 
1411 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1412   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1413   if (mlirAttributeIsNull(rawAttr))
1414     throw py::error_already_set();
1415   return PyAttribute(
1416       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1417 }
1418 
1419 //------------------------------------------------------------------------------
1420 // PyNamedAttribute.
1421 //------------------------------------------------------------------------------
1422 
1423 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1424     : ownedName(new std::string(std::move(ownedName))) {
1425   namedAttr = mlirNamedAttributeGet(
1426       mlirIdentifierGet(mlirAttributeGetContext(attr),
1427                         toMlirStringRef(*this->ownedName)),
1428       attr);
1429 }
1430 
1431 //------------------------------------------------------------------------------
1432 // PyType.
1433 //------------------------------------------------------------------------------
1434 
1435 bool PyType::operator==(const PyType &other) {
1436   return mlirTypeEqual(type, other.type);
1437 }
1438 
1439 py::object PyType::getCapsule() {
1440   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1441 }
1442 
1443 PyType PyType::createFromCapsule(py::object capsule) {
1444   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1445   if (mlirTypeIsNull(rawType))
1446     throw py::error_already_set();
1447   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1448                 rawType);
1449 }
1450 
1451 //------------------------------------------------------------------------------
// PyValue and subclasses.
1453 //------------------------------------------------------------------------------
1454 
1455 namespace {
1456 /// CRTP base class for Python MLIR values that subclass Value and should be
1457 /// castable from it. The value hierarchy is one level deep and is not supposed
1458 /// to accommodate other levels unless core MLIR changes.
1459 template <typename DerivedTy>
1460 class PyConcreteValue : public PyValue {
1461 public:
1462   // Derived classes must define statics for:
1463   //   IsAFunctionTy isaFunction
1464   //   const char *pyClassName
1465   // and redefine bindDerived.
1466   using ClassTy = py::class_<DerivedTy, PyValue>;
1467   using IsAFunctionTy = bool (*)(MlirValue);
1468 
1469   PyConcreteValue() = default;
1470   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1471       : PyValue(operationRef, value) {}
1472   PyConcreteValue(PyValue &orig)
1473       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1474 
1475   /// Attempts to cast the original value to the derived type and throws on
1476   /// type mismatches.
1477   static MlirValue castFrom(PyValue &orig) {
1478     if (!DerivedTy::isaFunction(orig.get())) {
1479       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1480       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1481                                              DerivedTy::pyClassName +
1482                                              " (from " + origRepr + ")");
1483     }
1484     return orig.get();
1485   }
1486 
1487   /// Binds the Python module objects to functions of this class.
1488   static void bind(py::module &m) {
1489     auto cls = ClassTy(m, DerivedTy::pyClassName);
1490     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1491     DerivedTy::bindDerived(cls);
1492   }
1493 
1494   /// Implemented by derived classes to add methods to the Python subclass.
1495   static void bindDerived(ClassTy &m) {}
1496 };
1497 
1498 /// Python wrapper for MlirBlockArgument.
1499 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1500 public:
1501   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1502   static constexpr const char *pyClassName = "BlockArgument";
1503   using PyConcreteValue::PyConcreteValue;
1504 
1505   static void bindDerived(ClassTy &c) {
1506     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1507       return PyBlock(self.getParentOperation(),
1508                      mlirBlockArgumentGetOwner(self.get()));
1509     });
1510     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1511       return mlirBlockArgumentGetArgNumber(self.get());
1512     });
1513     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1514       return mlirBlockArgumentSetType(self.get(), type);
1515     });
1516   }
1517 };
1518 
1519 /// Python wrapper for MlirOpResult.
1520 class PyOpResult : public PyConcreteValue<PyOpResult> {
1521 public:
1522   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1523   static constexpr const char *pyClassName = "OpResult";
1524   using PyConcreteValue::PyConcreteValue;
1525 
1526   static void bindDerived(ClassTy &c) {
1527     c.def_property_readonly("owner", [](PyOpResult &self) {
1528       assert(
1529           mlirOperationEqual(self.getParentOperation()->get(),
1530                              mlirOpResultGetOwner(self.get())) &&
1531           "expected the owner of the value in Python to match that in the IR");
1532       return self.getParentOperation();
1533     });
1534     c.def_property_readonly("result_number", [](PyOpResult &self) {
1535       return mlirOpResultGetResultNumber(self.get());
1536     });
1537   }
1538 };
1539 
1540 /// A list of block arguments. Internally, these are stored as consecutive
1541 /// elements, random access is cheap. The argument list is associated with the
1542 /// operation that contains the block (detached blocks are not allowed in
1543 /// Python bindings) and extends its lifetime.
1544 class PyBlockArgumentList {
1545 public:
1546   PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
1547       : operation(std::move(operation)), block(block) {}
1548 
1549   /// Returns the length of the block argument list.
1550   intptr_t dunderLen() {
1551     operation->checkValid();
1552     return mlirBlockGetNumArguments(block);
1553   }
1554 
1555   /// Returns `index`-th element of the block argument list.
1556   PyBlockArgument dunderGetItem(intptr_t index) {
1557     if (index < 0 || index >= dunderLen()) {
1558       throw SetPyError(PyExc_IndexError,
1559                        "attempt to access out of bounds region");
1560     }
1561     PyValue value(operation, mlirBlockGetArgument(block, index));
1562     return PyBlockArgument(value);
1563   }
1564 
1565   /// Defines a Python class in the bindings.
1566   static void bind(py::module &m) {
1567     py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
1568         .def("__len__", &PyBlockArgumentList::dunderLen)
1569         .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
1570   }
1571 
1572 private:
1573   PyOperationRef operation;
1574   MlirBlock block;
1575 };
1576 
1577 /// A list of operation operands. Internally, these are stored as consecutive
1578 /// elements, random access is cheap. The result list is associated with the
1579 /// operation whose results these are, and extends the lifetime of this
1580 /// operation.
1581 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1582 public:
1583   static constexpr const char *pyClassName = "OpOperandList";
1584 
1585   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1586                   intptr_t length = -1, intptr_t step = 1)
1587       : Sliceable(startIndex,
1588                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1589                                : length,
1590                   step),
1591         operation(operation) {}
1592 
1593   intptr_t getNumElements() {
1594     operation->checkValid();
1595     return mlirOperationGetNumOperands(operation->get());
1596   }
1597 
1598   PyValue getElement(intptr_t pos) {
1599     return PyValue(operation, mlirOperationGetOperand(operation->get(), pos));
1600   }
1601 
1602   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1603     return PyOpOperandList(operation, startIndex, length, step);
1604   }
1605 
1606 private:
1607   PyOperationRef operation;
1608 };
1609 
1610 /// A list of operation results. Internally, these are stored as consecutive
1611 /// elements, random access is cheap. The result list is associated with the
1612 /// operation whose results these are, and extends the lifetime of this
1613 /// operation.
1614 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1615 public:
1616   static constexpr const char *pyClassName = "OpResultList";
1617 
1618   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1619                  intptr_t length = -1, intptr_t step = 1)
1620       : Sliceable(startIndex,
1621                   length == -1 ? mlirOperationGetNumResults(operation->get())
1622                                : length,
1623                   step),
1624         operation(operation) {}
1625 
1626   intptr_t getNumElements() {
1627     operation->checkValid();
1628     return mlirOperationGetNumResults(operation->get());
1629   }
1630 
1631   PyOpResult getElement(intptr_t index) {
1632     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1633     return PyOpResult(value);
1634   }
1635 
1636   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1637     return PyOpResultList(operation, startIndex, length, step);
1638   }
1639 
1640 private:
1641   PyOperationRef operation;
1642 };
1643 
1644 /// A list of operation attributes. Can be indexed by name, producing
1645 /// attributes, or by index, producing named attributes.
1646 class PyOpAttributeMap {
1647 public:
1648   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1649 
1650   PyAttribute dunderGetItemNamed(const std::string &name) {
1651     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1652                                                          toMlirStringRef(name));
1653     if (mlirAttributeIsNull(attr)) {
1654       throw SetPyError(PyExc_KeyError,
1655                        "attempt to access a non-existent attribute");
1656     }
1657     return PyAttribute(operation->getContext(), attr);
1658   }
1659 
1660   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1661     if (index < 0 || index >= dunderLen()) {
1662       throw SetPyError(PyExc_IndexError,
1663                        "attempt to access out of bounds attribute");
1664     }
1665     MlirNamedAttribute namedAttr =
1666         mlirOperationGetAttribute(operation->get(), index);
1667     return PyNamedAttribute(
1668         namedAttr.attribute,
1669         std::string(mlirIdentifierStr(namedAttr.name).data));
1670   }
1671 
1672   void dunderSetItem(const std::string &name, PyAttribute attr) {
1673     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1674                                     attr);
1675   }
1676 
1677   void dunderDelItem(const std::string &name) {
1678     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1679                                                      toMlirStringRef(name));
1680     if (!removed)
1681       throw SetPyError(PyExc_KeyError,
1682                        "attempt to delete a non-existent attribute");
1683   }
1684 
1685   intptr_t dunderLen() {
1686     return mlirOperationGetNumAttributes(operation->get());
1687   }
1688 
1689   bool dunderContains(const std::string &name) {
1690     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1691         operation->get(), toMlirStringRef(name)));
1692   }
1693 
1694   static void bind(py::module &m) {
1695     py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
1696         .def("__contains__", &PyOpAttributeMap::dunderContains)
1697         .def("__len__", &PyOpAttributeMap::dunderLen)
1698         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1699         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1700         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1701         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1702   }
1703 
1704 private:
1705   PyOperationRef operation;
1706 };
1707 
1708 } // end namespace
1709 
1710 //------------------------------------------------------------------------------
1711 // Populates the core exports of the 'ir' submodule.
1712 //------------------------------------------------------------------------------
1713 
1714 void mlir::python::populateIRCore(py::module &m) {
1715   //----------------------------------------------------------------------------
1716   // Mapping of MlirContext
1717   //----------------------------------------------------------------------------
1718   py::class_<PyMlirContext>(m, "Context")
1719       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1720       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1721       .def("_get_context_again",
1722            [](PyMlirContext &self) {
1723              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1724              return ref.releaseObject();
1725            })
1726       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1727       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1728       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1729                              &PyMlirContext::getCapsule)
1730       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1731       .def("__enter__", &PyMlirContext::contextEnter)
1732       .def("__exit__", &PyMlirContext::contextExit)
1733       .def_property_readonly_static(
1734           "current",
1735           [](py::object & /*class*/) {
1736             auto *context = PyThreadContextEntry::getDefaultContext();
1737             if (!context)
1738               throw SetPyError(PyExc_ValueError, "No current Context");
1739             return context;
1740           },
1741           "Gets the Context bound to the current thread or raises ValueError")
1742       .def_property_readonly(
1743           "dialects",
1744           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1745           "Gets a container for accessing dialects by name")
1746       .def_property_readonly(
1747           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1748           "Alias for 'dialect'")
1749       .def(
1750           "get_dialect_descriptor",
1751           [=](PyMlirContext &self, std::string &name) {
1752             MlirDialect dialect = mlirContextGetOrLoadDialect(
1753                 self.get(), {name.data(), name.size()});
1754             if (mlirDialectIsNull(dialect)) {
1755               throw SetPyError(PyExc_ValueError,
1756                                Twine("Dialect '") + name + "' not found");
1757             }
1758             return PyDialectDescriptor(self.getRef(), dialect);
1759           },
1760           "Gets or loads a dialect by name, returning its descriptor object")
1761       .def_property(
1762           "allow_unregistered_dialects",
1763           [](PyMlirContext &self) -> bool {
1764             return mlirContextGetAllowUnregisteredDialects(self.get());
1765           },
1766           [](PyMlirContext &self, bool value) {
1767             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1768           })
1769       .def("is_registered_operation",
1770            [](PyMlirContext &self, std::string &name) {
1771              return mlirContextIsRegisteredOperation(
1772                  self.get(), MlirStringRef{name.data(), name.size()});
1773            });
1774 
1775   //----------------------------------------------------------------------------
1776   // Mapping of PyDialectDescriptor
1777   //----------------------------------------------------------------------------
1778   py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
1779       .def_property_readonly("namespace",
1780                              [](PyDialectDescriptor &self) {
1781                                MlirStringRef ns =
1782                                    mlirDialectGetNamespace(self.get());
1783                                return py::str(ns.data, ns.length);
1784                              })
1785       .def("__repr__", [](PyDialectDescriptor &self) {
1786         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1787         std::string repr("<DialectDescriptor ");
1788         repr.append(ns.data, ns.length);
1789         repr.append(">");
1790         return repr;
1791       });
1792 
1793   //----------------------------------------------------------------------------
1794   // Mapping of PyDialects
1795   //----------------------------------------------------------------------------
1796   py::class_<PyDialects>(m, "Dialects")
1797       .def("__getitem__",
1798            [=](PyDialects &self, std::string keyName) {
1799              MlirDialect dialect =
1800                  self.getDialectForKey(keyName, /*attrError=*/false);
1801              py::object descriptor =
1802                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1803              return createCustomDialectWrapper(keyName, std::move(descriptor));
1804            })
1805       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1806         MlirDialect dialect =
1807             self.getDialectForKey(attrName, /*attrError=*/true);
1808         py::object descriptor =
1809             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1810         return createCustomDialectWrapper(attrName, std::move(descriptor));
1811       });
1812 
1813   //----------------------------------------------------------------------------
1814   // Mapping of PyDialect
1815   //----------------------------------------------------------------------------
1816   py::class_<PyDialect>(m, "Dialect")
1817       .def(py::init<py::object>(), "descriptor")
1818       .def_property_readonly(
1819           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1820       .def("__repr__", [](py::object self) {
1821         auto clazz = self.attr("__class__");
1822         return py::str("<Dialect ") +
1823                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1824                clazz.attr("__module__") + py::str(".") +
1825                clazz.attr("__name__") + py::str(")>");
1826       });
1827 
1828   //----------------------------------------------------------------------------
1829   // Mapping of Location
1830   //----------------------------------------------------------------------------
1831   py::class_<PyLocation>(m, "Location")
1832       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
1833       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
1834       .def("__enter__", &PyLocation::contextEnter)
1835       .def("__exit__", &PyLocation::contextExit)
1836       .def("__eq__",
1837            [](PyLocation &self, PyLocation &other) -> bool {
1838              return mlirLocationEqual(self, other);
1839            })
1840       .def("__eq__", [](PyLocation &self, py::object other) { return false; })
1841       .def_property_readonly_static(
1842           "current",
1843           [](py::object & /*class*/) {
1844             auto *loc = PyThreadContextEntry::getDefaultLocation();
1845             if (!loc)
1846               throw SetPyError(PyExc_ValueError, "No current Location");
1847             return loc;
1848           },
1849           "Gets the Location bound to the current thread or raises ValueError")
1850       .def_static(
1851           "unknown",
1852           [](DefaultingPyMlirContext context) {
1853             return PyLocation(context->getRef(),
1854                               mlirLocationUnknownGet(context->get()));
1855           },
1856           py::arg("context") = py::none(),
1857           "Gets a Location representing an unknown location")
1858       .def_static(
1859           "file",
1860           [](std::string filename, int line, int col,
1861              DefaultingPyMlirContext context) {
1862             return PyLocation(
1863                 context->getRef(),
1864                 mlirLocationFileLineColGet(
1865                     context->get(), toMlirStringRef(filename), line, col));
1866           },
1867           py::arg("filename"), py::arg("line"), py::arg("col"),
1868           py::arg("context") = py::none(), kContextGetFileLocationDocstring)
1869       .def_property_readonly(
1870           "context",
1871           [](PyLocation &self) { return self.getContext().getObject(); },
1872           "Context that owns the Location")
1873       .def("__repr__", [](PyLocation &self) {
1874         PyPrintAccumulator printAccum;
1875         mlirLocationPrint(self, printAccum.getCallback(),
1876                           printAccum.getUserData());
1877         return printAccum.join();
1878       });
1879 
1880   //----------------------------------------------------------------------------
1881   // Mapping of Module
1882   //----------------------------------------------------------------------------
1883   py::class_<PyModule>(m, "Module")
1884       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
1885       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
1886       .def_static(
1887           "parse",
1888           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
1889             MlirModule module = mlirModuleCreateParse(
1890                 context->get(), toMlirStringRef(moduleAsm));
1891             // TODO: Rework error reporting once diagnostic engine is exposed
1892             // in C API.
1893             if (mlirModuleIsNull(module)) {
1894               throw SetPyError(
1895                   PyExc_ValueError,
1896                   "Unable to parse module assembly (see diagnostics)");
1897             }
1898             return PyModule::forModule(module).releaseObject();
1899           },
1900           py::arg("asm"), py::arg("context") = py::none(),
1901           kModuleParseDocstring)
1902       .def_static(
1903           "create",
1904           [](DefaultingPyLocation loc) {
1905             MlirModule module = mlirModuleCreateEmpty(loc);
1906             return PyModule::forModule(module).releaseObject();
1907           },
1908           py::arg("loc") = py::none(), "Creates an empty module")
1909       .def_property_readonly(
1910           "context",
1911           [](PyModule &self) { return self.getContext().getObject(); },
1912           "Context that created the Module")
1913       .def_property_readonly(
1914           "operation",
1915           [](PyModule &self) {
1916             return PyOperation::forOperation(self.getContext(),
1917                                              mlirModuleGetOperation(self.get()),
1918                                              self.getRef().releaseObject())
1919                 .releaseObject();
1920           },
1921           "Accesses the module as an operation")
1922       .def_property_readonly(
1923           "body",
1924           [](PyModule &self) {
1925             PyOperationRef module_op = PyOperation::forOperation(
1926                 self.getContext(), mlirModuleGetOperation(self.get()),
1927                 self.getRef().releaseObject());
1928             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
1929             return returnBlock;
1930           },
1931           "Return the block for this module")
1932       .def(
1933           "dump",
1934           [](PyModule &self) {
1935             mlirOperationDump(mlirModuleGetOperation(self.get()));
1936           },
1937           kDumpDocstring)
1938       .def(
1939           "__str__",
1940           [](PyModule &self) {
1941             MlirOperation operation = mlirModuleGetOperation(self.get());
1942             PyPrintAccumulator printAccum;
1943             mlirOperationPrint(operation, printAccum.getCallback(),
1944                                printAccum.getUserData());
1945             return printAccum.join();
1946           },
1947           kOperationStrDunderDocstring);
1948 
1949   //----------------------------------------------------------------------------
1950   // Mapping of Operation.
1951   //----------------------------------------------------------------------------
1952   py::class_<PyOperationBase>(m, "_OperationBase")
1953       .def("__eq__",
1954            [](PyOperationBase &self, PyOperationBase &other) {
1955              return &self.getOperation() == &other.getOperation();
1956            })
1957       .def("__eq__",
1958            [](PyOperationBase &self, py::object other) { return false; })
1959       .def_property_readonly("attributes",
1960                              [](PyOperationBase &self) {
1961                                return PyOpAttributeMap(
1962                                    self.getOperation().getRef());
1963                              })
1964       .def_property_readonly("operands",
1965                              [](PyOperationBase &self) {
1966                                return PyOpOperandList(
1967                                    self.getOperation().getRef());
1968                              })
1969       .def_property_readonly("regions",
1970                              [](PyOperationBase &self) {
1971                                return PyRegionList(
1972                                    self.getOperation().getRef());
1973                              })
1974       .def_property_readonly(
1975           "results",
1976           [](PyOperationBase &self) {
1977             return PyOpResultList(self.getOperation().getRef());
1978           },
1979           "Returns the list of Operation results.")
1980       .def_property_readonly(
1981           "result",
1982           [](PyOperationBase &self) {
1983             auto &operation = self.getOperation();
1984             auto numResults = mlirOperationGetNumResults(operation);
1985             if (numResults != 1) {
1986               auto name = mlirIdentifierStr(mlirOperationGetName(operation));
1987               throw SetPyError(
1988                   PyExc_ValueError,
1989                   Twine("Cannot call .result on operation ") +
1990                       StringRef(name.data, name.length) + " which has " +
1991                       Twine(numResults) +
1992                       " results (it is only valid for operations with a "
1993                       "single result)");
1994             }
1995             return PyOpResult(operation.getRef(),
1996                               mlirOperationGetResult(operation, 0));
1997           },
1998           "Shortcut to get an op result if it has only one (throws an error "
1999           "otherwise).")
2000       .def("__iter__",
2001            [](PyOperationBase &self) {
2002              return PyRegionIterator(self.getOperation().getRef());
2003            })
2004       .def(
2005           "__str__",
2006           [](PyOperationBase &self) {
2007             return self.getAsm(/*binary=*/false,
2008                                /*largeElementsLimit=*/llvm::None,
2009                                /*enableDebugInfo=*/false,
2010                                /*prettyDebugInfo=*/false,
2011                                /*printGenericOpForm=*/false,
2012                                /*useLocalScope=*/false);
2013           },
2014           "Returns the assembly form of the operation.")
2015       .def("print", &PyOperationBase::print,
2016            // Careful: Lots of arguments must match up with print method.
2017            py::arg("file") = py::none(), py::arg("binary") = false,
2018            py::arg("large_elements_limit") = py::none(),
2019            py::arg("enable_debug_info") = false,
2020            py::arg("pretty_debug_info") = false,
2021            py::arg("print_generic_op_form") = false,
2022            py::arg("use_local_scope") = false, kOperationPrintDocstring)
2023       .def("get_asm", &PyOperationBase::getAsm,
2024            // Careful: Lots of arguments must match up with get_asm method.
2025            py::arg("binary") = false,
2026            py::arg("large_elements_limit") = py::none(),
2027            py::arg("enable_debug_info") = false,
2028            py::arg("pretty_debug_info") = false,
2029            py::arg("print_generic_op_form") = false,
2030            py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
2031       .def(
2032           "verify",
2033           [](PyOperationBase &self) {
2034             return mlirOperationVerify(self.getOperation());
2035           },
2036           "Verify the operation and return true if it passes, false if it "
2037           "fails.");
2038 
2039   py::class_<PyOperation, PyOperationBase>(m, "Operation")
2040       .def_static("create", &PyOperation::create, py::arg("name"),
2041                   py::arg("results") = py::none(),
2042                   py::arg("operands") = py::none(),
2043                   py::arg("attributes") = py::none(),
2044                   py::arg("successors") = py::none(), py::arg("regions") = 0,
2045                   py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2046                   kOperationCreateDocstring)
2047       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2048                              &PyOperation::getCapsule)
2049       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule)
2050       .def_property_readonly("name",
2051                              [](PyOperation &self) {
2052                                MlirOperation operation = self.get();
2053                                MlirStringRef name = mlirIdentifierStr(
2054                                    mlirOperationGetName(operation));
2055                                return py::str(name.data, name.length);
2056                              })
2057       .def_property_readonly(
2058           "context",
2059           [](PyOperation &self) { return self.getContext().getObject(); },
2060           "Context that owns the Operation")
2061       .def_property_readonly("opview", &PyOperation::createOpView);
2062 
2063   auto opViewClass =
2064       py::class_<PyOpView, PyOperationBase>(m, "OpView")
2065           .def(py::init<py::object>())
2066           .def_property_readonly("operation", &PyOpView::getOperationObject)
2067           .def_property_readonly(
2068               "context",
2069               [](PyOpView &self) {
2070                 return self.getOperation().getContext().getObject();
2071               },
2072               "Context that owns the Operation")
2073           .def("__str__", [](PyOpView &self) {
2074             return py::str(self.getOperationObject());
2075           });
2076   opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
2077   opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
2078   opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
2079   opViewClass.attr("build_generic") = classmethod(
2080       &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
2081       py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
2082       py::arg("successors") = py::none(), py::arg("regions") = py::none(),
2083       py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2084       "Builds a specific, generated OpView based on class level attributes.");
2085 
2086   //----------------------------------------------------------------------------
2087   // Mapping of PyRegion.
2088   //----------------------------------------------------------------------------
2089   py::class_<PyRegion>(m, "Region")
2090       .def_property_readonly(
2091           "blocks",
2092           [](PyRegion &self) {
2093             return PyBlockList(self.getParentOperation(), self.get());
2094           },
2095           "Returns a forward-optimized sequence of blocks.")
2096       .def(
2097           "__iter__",
2098           [](PyRegion &self) {
2099             self.checkValid();
2100             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2101             return PyBlockIterator(self.getParentOperation(), firstBlock);
2102           },
2103           "Iterates over blocks in the region.")
2104       .def("__eq__",
2105            [](PyRegion &self, PyRegion &other) {
2106              return self.get().ptr == other.get().ptr;
2107            })
2108       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2109 
2110   //----------------------------------------------------------------------------
2111   // Mapping of PyBlock.
2112   //----------------------------------------------------------------------------
2113   py::class_<PyBlock>(m, "Block")
2114       .def_property_readonly(
2115           "arguments",
2116           [](PyBlock &self) {
2117             return PyBlockArgumentList(self.getParentOperation(), self.get());
2118           },
2119           "Returns a list of block arguments.")
2120       .def_property_readonly(
2121           "operations",
2122           [](PyBlock &self) {
2123             return PyOperationList(self.getParentOperation(), self.get());
2124           },
2125           "Returns a forward-optimized sequence of operations.")
2126       .def(
2127           "__iter__",
2128           [](PyBlock &self) {
2129             self.checkValid();
2130             MlirOperation firstOperation =
2131                 mlirBlockGetFirstOperation(self.get());
2132             return PyOperationIterator(self.getParentOperation(),
2133                                        firstOperation);
2134           },
2135           "Iterates over operations in the block.")
2136       .def("__eq__",
2137            [](PyBlock &self, PyBlock &other) {
2138              return self.get().ptr == other.get().ptr;
2139            })
2140       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2141       .def(
2142           "__str__",
2143           [](PyBlock &self) {
2144             self.checkValid();
2145             PyPrintAccumulator printAccum;
2146             mlirBlockPrint(self.get(), printAccum.getCallback(),
2147                            printAccum.getUserData());
2148             return printAccum.join();
2149           },
2150           "Returns the assembly form of the block.");
2151 
2152   //----------------------------------------------------------------------------
2153   // Mapping of PyInsertionPoint.
2154   //----------------------------------------------------------------------------
2155 
2156   py::class_<PyInsertionPoint>(m, "InsertionPoint")
2157       .def(py::init<PyBlock &>(), py::arg("block"),
2158            "Inserts after the last operation but still inside the block.")
2159       .def("__enter__", &PyInsertionPoint::contextEnter)
2160       .def("__exit__", &PyInsertionPoint::contextExit)
2161       .def_property_readonly_static(
2162           "current",
2163           [](py::object & /*class*/) {
2164             auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2165             if (!ip)
2166               throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2167             return ip;
2168           },
2169           "Gets the InsertionPoint bound to the current thread or raises "
2170           "ValueError if none has been set")
2171       .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2172            "Inserts before a referenced operation.")
2173       .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2174                   py::arg("block"), "Inserts at the beginning of the block.")
2175       .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2176                   py::arg("block"), "Inserts before the block terminator.")
2177       .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2178            "Inserts an operation.");
2179 
2180   //----------------------------------------------------------------------------
2181   // Mapping of PyAttribute.
2182   //----------------------------------------------------------------------------
2183   py::class_<PyAttribute>(m, "Attribute")
2184       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2185                              &PyAttribute::getCapsule)
2186       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2187       .def_static(
2188           "parse",
2189           [](std::string attrSpec, DefaultingPyMlirContext context) {
2190             MlirAttribute type = mlirAttributeParseGet(
2191                 context->get(), toMlirStringRef(attrSpec));
2192             // TODO: Rework error reporting once diagnostic engine is exposed
2193             // in C API.
2194             if (mlirAttributeIsNull(type)) {
2195               throw SetPyError(PyExc_ValueError,
2196                                Twine("Unable to parse attribute: '") +
2197                                    attrSpec + "'");
2198             }
2199             return PyAttribute(context->getRef(), type);
2200           },
2201           py::arg("asm"), py::arg("context") = py::none(),
2202           "Parses an attribute from an assembly form")
2203       .def_property_readonly(
2204           "context",
2205           [](PyAttribute &self) { return self.getContext().getObject(); },
2206           "Context that owns the Attribute")
2207       .def_property_readonly("type",
2208                              [](PyAttribute &self) {
2209                                return PyType(self.getContext()->getRef(),
2210                                              mlirAttributeGetType(self));
2211                              })
2212       .def(
2213           "get_named",
2214           [](PyAttribute &self, std::string name) {
2215             return PyNamedAttribute(self, std::move(name));
2216           },
2217           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2218       .def("__eq__",
2219            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2220       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2221       .def(
2222           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2223           kDumpDocstring)
2224       .def(
2225           "__str__",
2226           [](PyAttribute &self) {
2227             PyPrintAccumulator printAccum;
2228             mlirAttributePrint(self, printAccum.getCallback(),
2229                                printAccum.getUserData());
2230             return printAccum.join();
2231           },
2232           "Returns the assembly form of the Attribute.")
2233       .def("__repr__", [](PyAttribute &self) {
2234         // Generally, assembly formats are not printed for __repr__ because
2235         // this can cause exceptionally long debug output and exceptions.
2236         // However, attribute values are generally considered useful and are
2237         // printed. This may need to be re-evaluated if debug dumps end up
2238         // being excessive.
2239         PyPrintAccumulator printAccum;
2240         printAccum.parts.append("Attribute(");
2241         mlirAttributePrint(self, printAccum.getCallback(),
2242                            printAccum.getUserData());
2243         printAccum.parts.append(")");
2244         return printAccum.join();
2245       });
2246 
2247   //----------------------------------------------------------------------------
  // Mapping of PyNamedAttribute.
2249   //----------------------------------------------------------------------------
2250   py::class_<PyNamedAttribute>(m, "NamedAttribute")
2251       .def("__repr__",
2252            [](PyNamedAttribute &self) {
2253              PyPrintAccumulator printAccum;
2254              printAccum.parts.append("NamedAttribute(");
2255              printAccum.parts.append(
2256                  mlirIdentifierStr(self.namedAttr.name).data);
2257              printAccum.parts.append("=");
2258              mlirAttributePrint(self.namedAttr.attribute,
2259                                 printAccum.getCallback(),
2260                                 printAccum.getUserData());
2261              printAccum.parts.append(")");
2262              return printAccum.join();
2263            })
2264       .def_property_readonly(
2265           "name",
2266           [](PyNamedAttribute &self) {
2267             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2268                            mlirIdentifierStr(self.namedAttr.name).length);
2269           },
2270           "The name of the NamedAttribute binding")
2271       .def_property_readonly(
2272           "attr",
2273           [](PyNamedAttribute &self) {
2274             // TODO: When named attribute is removed/refactored, also remove
2275             // this constructor (it does an inefficient table lookup).
2276             auto contextRef = PyMlirContext::forContext(
2277                 mlirAttributeGetContext(self.namedAttr.attribute));
2278             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2279           },
2280           py::keep_alive<0, 1>(),
2281           "The underlying generic attribute of the NamedAttribute binding");
2282 
2283   //----------------------------------------------------------------------------
2284   // Mapping of PyType.
2285   //----------------------------------------------------------------------------
2286   py::class_<PyType>(m, "Type")
2287       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2288       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2289       .def_static(
2290           "parse",
2291           [](std::string typeSpec, DefaultingPyMlirContext context) {
2292             MlirType type =
2293                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2294             // TODO: Rework error reporting once diagnostic engine is exposed
2295             // in C API.
2296             if (mlirTypeIsNull(type)) {
2297               throw SetPyError(PyExc_ValueError,
2298                                Twine("Unable to parse type: '") + typeSpec +
2299                                    "'");
2300             }
2301             return PyType(context->getRef(), type);
2302           },
2303           py::arg("asm"), py::arg("context") = py::none(),
2304           kContextParseTypeDocstring)
2305       .def_property_readonly(
2306           "context", [](PyType &self) { return self.getContext().getObject(); },
2307           "Context that owns the Type")
2308       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2309       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2310       .def(
2311           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2312       .def(
2313           "__str__",
2314           [](PyType &self) {
2315             PyPrintAccumulator printAccum;
2316             mlirTypePrint(self, printAccum.getCallback(),
2317                           printAccum.getUserData());
2318             return printAccum.join();
2319           },
2320           "Returns the assembly form of the type.")
2321       .def("__repr__", [](PyType &self) {
2322         // Generally, assembly formats are not printed for __repr__ because
2323         // this can cause exceptionally long debug output and exceptions.
2324         // However, types are an exception as they typically have compact
2325         // assembly forms and printing them is useful.
2326         PyPrintAccumulator printAccum;
2327         printAccum.parts.append("Type(");
2328         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2329         printAccum.parts.append(")");
2330         return printAccum.join();
2331       });
2332 
2333   //----------------------------------------------------------------------------
  // Mapping of PyValue.
2335   //----------------------------------------------------------------------------
2336   py::class_<PyValue>(m, "Value")
2337       .def_property_readonly(
2338           "context",
2339           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2340           "Context in which the value lives.")
2341       .def(
2342           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2343           kDumpDocstring)
2344       .def("__eq__",
2345            [](PyValue &self, PyValue &other) {
2346              return self.get().ptr == other.get().ptr;
2347            })
2348       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2349       .def(
2350           "__str__",
2351           [](PyValue &self) {
2352             PyPrintAccumulator printAccum;
2353             printAccum.parts.append("Value(");
2354             mlirValuePrint(self.get(), printAccum.getCallback(),
2355                            printAccum.getUserData());
2356             printAccum.parts.append(")");
2357             return printAccum.join();
2358           },
2359           kValueDunderStrDocstring)
2360       .def_property_readonly("type", [](PyValue &self) {
2361         return PyType(self.getParentOperation()->getContext(),
2362                       mlirValueGetType(self.get()));
2363       });
  // Bindings for block arguments and operation results, registered right
  // after "Value".
  // NOTE(review): these are presumably Python subclasses of Value and must be
  // registered after it; confirm before reordering.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings.
  // NOTE(review): registration order may affect pybind11-generated
  // signatures (types referenced before they are registered print as C++
  // names); keep this order unless verified.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);
2378 }
2379