1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "IRModule.h"
10 
11 #include "Globals.h"
12 #include "PybindUtils.h"
13 
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Registration.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include <pybind11/stl.h>
20 
21 namespace py = pybind11;
22 using namespace mlir;
23 using namespace mlir::python;
24 
25 using llvm::SmallVector;
26 using llvm::StringRef;
27 using llvm::Twine;
28 
29 //------------------------------------------------------------------------------
30 // Docstrings (trivial, non-duplicated docstrings are included inline).
31 //------------------------------------------------------------------------------
32 
// Docstring text describing type parsing: a Type is returned on success and a
// ValueError is raised otherwise. The binding that attaches this text is not
// in this part of the file.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";
40 
// Docstring text for obtaining a file/line/column Location.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";
43 
// Docstring text for parsing a module from its textual assembly; documents
// that a ValueError is raised when parsing fails.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";
51 
// Docstring text for operation creation.
// NOTE(review): PyOperation::create in this file also takes an `operands`
// argument (between results and attributes) that is not described below —
// confirm the Python keyword name and consider documenting it.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
69 
// Docstring text for printing an operation to a file-like object. The
// argument names listed here mirror the parameters of PyOperationBase::print
// (snake_case on the Python side).
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Whether to elide elements attributes above this
    number of elements. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";
88 
// Docstring text for obtaining an operation's assembly as a bytes or str
// object; defers to the print() documentation for the shared keyword
// arguments.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";
101 
// Docstring text for the default (option-free) string form of an operation.
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";
109 
// Shared docstring text for debug-dump methods that write to stderr.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";
112 
// Docstring attached to the "append" method of the BlockList binding
// (PyBlockList::appendBlock).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";
119 
// Docstring text for the string form of a value (block argument or operation
// result).
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
127 
128 //------------------------------------------------------------------------------
129 // Utilities.
130 //------------------------------------------------------------------------------
131 
132 // Helper for creating an @classmethod.
133 template <class Func, typename... Args>
134 py::object classmethod(Func f, Args... args) {
135   py::object cf = py::cpp_function(f, args...);
136   return py::reinterpret_borrow<py::object>((PyClassMethod_New(cf.ptr())));
137 }
138 
139 static py::object
140 createCustomDialectWrapper(const std::string &dialectNamespace,
141                            py::object dialectDescriptor) {
142   auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
143   if (!dialectClass) {
144     // Use the base class.
145     return py::cast(PyDialect(std::move(dialectDescriptor)));
146   }
147 
148   // Create the custom implementation.
149   return (*dialectClass)(std::move(dialectDescriptor));
150 }
151 
152 static MlirStringRef toMlirStringRef(const std::string &s) {
153   return mlirStringRefCreate(s.data(), s.size());
154 }
155 
156 //------------------------------------------------------------------------------
157 // Collections.
158 //------------------------------------------------------------------------------
159 
160 namespace {
161 
162 class PyRegionIterator {
163 public:
164   PyRegionIterator(PyOperationRef operation)
165       : operation(std::move(operation)) {}
166 
167   PyRegionIterator &dunderIter() { return *this; }
168 
169   PyRegion dunderNext() {
170     operation->checkValid();
171     if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
172       throw py::stop_iteration();
173     }
174     MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
175     return PyRegion(operation, region);
176   }
177 
178   static void bind(py::module &m) {
179     py::class_<PyRegionIterator>(m, "RegionIterator")
180         .def("__iter__", &PyRegionIterator::dunderIter)
181         .def("__next__", &PyRegionIterator::dunderNext);
182   }
183 
184 private:
185   PyOperationRef operation;
186   int nextIndex = 0;
187 };
188 
189 /// Regions of an op are fixed length and indexed numerically so are represented
190 /// with a sequence-like container.
191 class PyRegionList {
192 public:
193   PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
194 
195   intptr_t dunderLen() {
196     operation->checkValid();
197     return mlirOperationGetNumRegions(operation->get());
198   }
199 
200   PyRegion dunderGetItem(intptr_t index) {
201     // dunderLen checks validity.
202     if (index < 0 || index >= dunderLen()) {
203       throw SetPyError(PyExc_IndexError,
204                        "attempt to access out of bounds region");
205     }
206     MlirRegion region = mlirOperationGetRegion(operation->get(), index);
207     return PyRegion(operation, region);
208   }
209 
210   static void bind(py::module &m) {
211     py::class_<PyRegionList>(m, "RegionSequence")
212         .def("__len__", &PyRegionList::dunderLen)
213         .def("__getitem__", &PyRegionList::dunderGetItem);
214   }
215 
216 private:
217   PyOperationRef operation;
218 };
219 
220 class PyBlockIterator {
221 public:
222   PyBlockIterator(PyOperationRef operation, MlirBlock next)
223       : operation(std::move(operation)), next(next) {}
224 
225   PyBlockIterator &dunderIter() { return *this; }
226 
227   PyBlock dunderNext() {
228     operation->checkValid();
229     if (mlirBlockIsNull(next)) {
230       throw py::stop_iteration();
231     }
232 
233     PyBlock returnBlock(operation, next);
234     next = mlirBlockGetNextInRegion(next);
235     return returnBlock;
236   }
237 
238   static void bind(py::module &m) {
239     py::class_<PyBlockIterator>(m, "BlockIterator")
240         .def("__iter__", &PyBlockIterator::dunderIter)
241         .def("__next__", &PyBlockIterator::dunderNext);
242   }
243 
244 private:
245   PyOperationRef operation;
246   MlirBlock next;
247 };
248 
249 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
250 /// we present them as a more full-featured list-like container but optimize
251 /// it for forward iteration. Blocks are always owned by a region.
252 class PyBlockList {
253 public:
254   PyBlockList(PyOperationRef operation, MlirRegion region)
255       : operation(std::move(operation)), region(region) {}
256 
257   PyBlockIterator dunderIter() {
258     operation->checkValid();
259     return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
260   }
261 
262   intptr_t dunderLen() {
263     operation->checkValid();
264     intptr_t count = 0;
265     MlirBlock block = mlirRegionGetFirstBlock(region);
266     while (!mlirBlockIsNull(block)) {
267       count += 1;
268       block = mlirBlockGetNextInRegion(block);
269     }
270     return count;
271   }
272 
273   PyBlock dunderGetItem(intptr_t index) {
274     operation->checkValid();
275     if (index < 0) {
276       throw SetPyError(PyExc_IndexError,
277                        "attempt to access out of bounds block");
278     }
279     MlirBlock block = mlirRegionGetFirstBlock(region);
280     while (!mlirBlockIsNull(block)) {
281       if (index == 0) {
282         return PyBlock(operation, block);
283       }
284       block = mlirBlockGetNextInRegion(block);
285       index -= 1;
286     }
287     throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
288   }
289 
290   PyBlock appendBlock(py::args pyArgTypes) {
291     operation->checkValid();
292     llvm::SmallVector<MlirType, 4> argTypes;
293     argTypes.reserve(pyArgTypes.size());
294     for (auto &pyArg : pyArgTypes) {
295       argTypes.push_back(pyArg.cast<PyType &>());
296     }
297 
298     MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
299     mlirRegionAppendOwnedBlock(region, block);
300     return PyBlock(operation, block);
301   }
302 
303   static void bind(py::module &m) {
304     py::class_<PyBlockList>(m, "BlockList")
305         .def("__getitem__", &PyBlockList::dunderGetItem)
306         .def("__iter__", &PyBlockList::dunderIter)
307         .def("__len__", &PyBlockList::dunderLen)
308         .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
309   }
310 
311 private:
312   PyOperationRef operation;
313   MlirRegion region;
314 };
315 
316 class PyOperationIterator {
317 public:
318   PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
319       : parentOperation(std::move(parentOperation)), next(next) {}
320 
321   PyOperationIterator &dunderIter() { return *this; }
322 
323   py::object dunderNext() {
324     parentOperation->checkValid();
325     if (mlirOperationIsNull(next)) {
326       throw py::stop_iteration();
327     }
328 
329     PyOperationRef returnOperation =
330         PyOperation::forOperation(parentOperation->getContext(), next);
331     next = mlirOperationGetNextInBlock(next);
332     return returnOperation->createOpView();
333   }
334 
335   static void bind(py::module &m) {
336     py::class_<PyOperationIterator>(m, "OperationIterator")
337         .def("__iter__", &PyOperationIterator::dunderIter)
338         .def("__next__", &PyOperationIterator::dunderNext);
339   }
340 
341 private:
342   PyOperationRef parentOperation;
343   MlirOperation next;
344 };
345 
346 /// Operations are exposed by the C-API as a forward-only linked list. In
347 /// Python, we present them as a more full-featured list-like container but
348 /// optimize it for forward iteration. Iterable operations are always owned
349 /// by a block.
350 class PyOperationList {
351 public:
352   PyOperationList(PyOperationRef parentOperation, MlirBlock block)
353       : parentOperation(std::move(parentOperation)), block(block) {}
354 
355   PyOperationIterator dunderIter() {
356     parentOperation->checkValid();
357     return PyOperationIterator(parentOperation,
358                                mlirBlockGetFirstOperation(block));
359   }
360 
361   intptr_t dunderLen() {
362     parentOperation->checkValid();
363     intptr_t count = 0;
364     MlirOperation childOp = mlirBlockGetFirstOperation(block);
365     while (!mlirOperationIsNull(childOp)) {
366       count += 1;
367       childOp = mlirOperationGetNextInBlock(childOp);
368     }
369     return count;
370   }
371 
372   py::object dunderGetItem(intptr_t index) {
373     parentOperation->checkValid();
374     if (index < 0) {
375       throw SetPyError(PyExc_IndexError,
376                        "attempt to access out of bounds operation");
377     }
378     MlirOperation childOp = mlirBlockGetFirstOperation(block);
379     while (!mlirOperationIsNull(childOp)) {
380       if (index == 0) {
381         return PyOperation::forOperation(parentOperation->getContext(), childOp)
382             ->createOpView();
383       }
384       childOp = mlirOperationGetNextInBlock(childOp);
385       index -= 1;
386     }
387     throw SetPyError(PyExc_IndexError,
388                      "attempt to access out of bounds operation");
389   }
390 
391   static void bind(py::module &m) {
392     py::class_<PyOperationList>(m, "OperationList")
393         .def("__getitem__", &PyOperationList::dunderGetItem)
394         .def("__iter__", &PyOperationList::dunderIter)
395         .def("__len__", &PyOperationList::dunderLen);
396   }
397 
398 private:
399   PyOperationRef parentOperation;
400   MlirBlock block;
401 };
402 
403 } // namespace
404 
405 //------------------------------------------------------------------------------
406 // PyMlirContext
407 //------------------------------------------------------------------------------
408 
409 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
410   py::gil_scoped_acquire acquire;
411   auto &liveContexts = getLiveContexts();
412   liveContexts[context.ptr] = this;
413 }
414 
415 PyMlirContext::~PyMlirContext() {
416   // Note that the only public way to construct an instance is via the
417   // forContext method, which always puts the associated handle into
418   // liveContexts.
419   py::gil_scoped_acquire acquire;
420   getLiveContexts().erase(context.ptr);
421   mlirContextDestroy(context);
422 }
423 
424 py::object PyMlirContext::getCapsule() {
425   return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
426 }
427 
428 py::object PyMlirContext::createFromCapsule(py::object capsule) {
429   MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
430   if (mlirContextIsNull(rawContext))
431     throw py::error_already_set();
432   return forContext(rawContext).releaseObject();
433 }
434 
435 PyMlirContext *PyMlirContext::createNewContextForInit() {
436   MlirContext context = mlirContextCreate();
437   mlirRegisterAllDialects(context);
438   return new PyMlirContext(context);
439 }
440 
441 PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
442   py::gil_scoped_acquire acquire;
443   auto &liveContexts = getLiveContexts();
444   auto it = liveContexts.find(context.ptr);
445   if (it == liveContexts.end()) {
446     // Create.
447     PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
448     py::object pyRef = py::cast(unownedContextWrapper);
449     assert(pyRef && "cast to py::object failed");
450     liveContexts[context.ptr] = unownedContextWrapper;
451     return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
452   }
453   // Use existing.
454   py::object pyRef = py::cast(it->second);
455   return PyMlirContextRef(it->second, std::move(pyRef));
456 }
457 
458 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
459   static LiveContextMap liveContexts;
460   return liveContexts;
461 }
462 
463 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
464 
465 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
466 
467 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
468 
469 pybind11::object PyMlirContext::contextEnter() {
470   return PyThreadContextEntry::pushContext(*this);
471 }
472 
/// Context-manager exit: pops this context from the thread's context stack.
/// The three exception arguments are accepted but not used.
void PyMlirContext::contextExit(pybind11::object excType,
                                pybind11::object excVal,
                                pybind11::object excTb) {
  PyThreadContextEntry::popContext(*this);
}
478 
479 PyMlirContext &DefaultingPyMlirContext::resolve() {
480   PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
481   if (!context) {
482     throw SetPyError(
483         PyExc_RuntimeError,
484         "An MLIR function requires a Context but none was provided in the call "
485         "or from the surrounding environment. Either pass to the function with "
486         "a 'context=' argument or establish a default using 'with Context():'");
487   }
488   return *context;
489 }
490 
491 //------------------------------------------------------------------------------
492 // PyThreadContextEntry management
493 //------------------------------------------------------------------------------
494 
495 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
496   static thread_local std::vector<PyThreadContextEntry> stack;
497   return stack;
498 }
499 
500 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
501   auto &stack = getStack();
502   if (stack.empty())
503     return nullptr;
504   return &stack.back();
505 }
506 
507 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
508                                 py::object insertionPoint,
509                                 py::object location) {
510   auto &stack = getStack();
511   stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
512                      std::move(location));
513   // If the new stack has more than one entry and the context of the new top
514   // entry matches the previous, copy the insertionPoint and location from the
515   // previous entry if missing from the new top entry.
516   if (stack.size() > 1) {
517     auto &prev = *(stack.rbegin() + 1);
518     auto &current = stack.back();
519     if (current.context.is(prev.context)) {
520       // Default non-context objects from the previous entry.
521       if (!current.insertionPoint)
522         current.insertionPoint = prev.insertionPoint;
523       if (!current.location)
524         current.location = prev.location;
525     }
526   }
527 }
528 
529 PyMlirContext *PyThreadContextEntry::getContext() {
530   if (!context)
531     return nullptr;
532   return py::cast<PyMlirContext *>(context);
533 }
534 
535 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
536   if (!insertionPoint)
537     return nullptr;
538   return py::cast<PyInsertionPoint *>(insertionPoint);
539 }
540 
541 PyLocation *PyThreadContextEntry::getLocation() {
542   if (!location)
543     return nullptr;
544   return py::cast<PyLocation *>(location);
545 }
546 
547 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
548   auto *tos = getTopOfStack();
549   return tos ? tos->getContext() : nullptr;
550 }
551 
552 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
553   auto *tos = getTopOfStack();
554   return tos ? tos->getInsertionPoint() : nullptr;
555 }
556 
557 PyLocation *PyThreadContextEntry::getDefaultLocation() {
558   auto *tos = getTopOfStack();
559   return tos ? tos->getLocation() : nullptr;
560 }
561 
562 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
563   py::object contextObj = py::cast(context);
564   push(FrameKind::Context, /*context=*/contextObj,
565        /*insertionPoint=*/py::object(),
566        /*location=*/py::object());
567   return contextObj;
568 }
569 
570 void PyThreadContextEntry::popContext(PyMlirContext &context) {
571   auto &stack = getStack();
572   if (stack.empty())
573     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
574   auto &tos = stack.back();
575   if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
576     throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
577   stack.pop_back();
578 }
579 
580 py::object
581 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
582   py::object contextObj =
583       insertionPoint.getBlock().getParentOperation()->getContext().getObject();
584   py::object insertionPointObj = py::cast(insertionPoint);
585   push(FrameKind::InsertionPoint,
586        /*context=*/contextObj,
587        /*insertionPoint=*/insertionPointObj,
588        /*location=*/py::object());
589   return insertionPointObj;
590 }
591 
592 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
593   auto &stack = getStack();
594   if (stack.empty())
595     throw SetPyError(PyExc_RuntimeError,
596                      "Unbalanced InsertionPoint enter/exit");
597   auto &tos = stack.back();
598   if (tos.frameKind != FrameKind::InsertionPoint &&
599       tos.getInsertionPoint() != &insertionPoint)
600     throw SetPyError(PyExc_RuntimeError,
601                      "Unbalanced InsertionPoint enter/exit");
602   stack.pop_back();
603 }
604 
605 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
606   py::object contextObj = location.getContext().getObject();
607   py::object locationObj = py::cast(location);
608   push(FrameKind::Location, /*context=*/contextObj,
609        /*insertionPoint=*/py::object(),
610        /*location=*/locationObj);
611   return locationObj;
612 }
613 
614 void PyThreadContextEntry::popLocation(PyLocation &location) {
615   auto &stack = getStack();
616   if (stack.empty())
617     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
618   auto &tos = stack.back();
619   if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
620     throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
621   stack.pop_back();
622 }
623 
624 //------------------------------------------------------------------------------
625 // PyDialect, PyDialectDescriptor, PyDialects
626 //------------------------------------------------------------------------------
627 
628 MlirDialect PyDialects::getDialectForKey(const std::string &key,
629                                          bool attrError) {
630   // If the "std" dialect was asked for, substitute the empty namespace :(
631   static const std::string emptyKey;
632   const std::string *canonKey = key == "std" ? &emptyKey : &key;
633   MlirDialect dialect = mlirContextGetOrLoadDialect(
634       getContext()->get(), {canonKey->data(), canonKey->size()});
635   if (mlirDialectIsNull(dialect)) {
636     throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
637                      Twine("Dialect '") + key + "' not found");
638   }
639   return dialect;
640 }
641 
642 //------------------------------------------------------------------------------
643 // PyLocation
644 //------------------------------------------------------------------------------
645 
646 py::object PyLocation::getCapsule() {
647   return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
648 }
649 
650 PyLocation PyLocation::createFromCapsule(py::object capsule) {
651   MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
652   if (mlirLocationIsNull(rawLoc))
653     throw py::error_already_set();
654   return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
655                     rawLoc);
656 }
657 
658 py::object PyLocation::contextEnter() {
659   return PyThreadContextEntry::pushLocation(*this);
660 }
661 
/// Context-manager exit: pops this location from the thread's context stack.
/// The three exception arguments are accepted but not used.
void PyLocation::contextExit(py::object excType, py::object excVal,
                             py::object excTb) {
  PyThreadContextEntry::popLocation(*this);
}
666 
667 PyLocation &DefaultingPyLocation::resolve() {
668   auto *location = PyThreadContextEntry::getDefaultLocation();
669   if (!location) {
670     throw SetPyError(
671         PyExc_RuntimeError,
672         "An MLIR function requires a Location but none was provided in the "
673         "call or from the surrounding environment. Either pass to the function "
674         "with a 'loc=' argument or establish a default using 'with loc:'");
675   }
676   return *location;
677 }
678 
679 //------------------------------------------------------------------------------
680 // PyModule
681 //------------------------------------------------------------------------------
682 
/// Stores the context reference in the BaseContextObject base and keeps the
/// raw module handle. Registration in the context's live-module map is not
/// done here; PyModule::forModule performs it.
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
    : BaseContextObject(std::move(contextRef)), module(module) {}
685 
686 PyModule::~PyModule() {
687   py::gil_scoped_acquire acquire;
688   auto &liveModules = getContext()->liveModules;
689   assert(liveModules.count(module.ptr) == 1 &&
690          "destroying module not in live map");
691   liveModules.erase(module.ptr);
692   mlirModuleDestroy(module);
693 }
694 
695 PyModuleRef PyModule::forModule(MlirModule module) {
696   MlirContext context = mlirModuleGetContext(module);
697   PyMlirContextRef contextRef = PyMlirContext::forContext(context);
698 
699   py::gil_scoped_acquire acquire;
700   auto &liveModules = contextRef->liveModules;
701   auto it = liveModules.find(module.ptr);
702   if (it == liveModules.end()) {
703     // Create.
704     PyModule *unownedModule = new PyModule(std::move(contextRef), module);
705     // Note that the default return value policy on cast is automatic_reference,
706     // which does not take ownership (delete will not be called).
707     // Just be explicit.
708     py::object pyRef =
709         py::cast(unownedModule, py::return_value_policy::take_ownership);
710     unownedModule->handle = pyRef;
711     liveModules[module.ptr] =
712         std::make_pair(unownedModule->handle, unownedModule);
713     return PyModuleRef(unownedModule, std::move(pyRef));
714   }
715   // Use existing.
716   PyModule *existing = it->second.second;
717   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
718   return PyModuleRef(existing, std::move(pyRef));
719 }
720 
721 py::object PyModule::createFromCapsule(py::object capsule) {
722   MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
723   if (mlirModuleIsNull(rawModule))
724     throw py::error_already_set();
725   return forModule(rawModule).releaseObject();
726 }
727 
728 py::object PyModule::getCapsule() {
729   return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
730 }
731 
732 //------------------------------------------------------------------------------
733 // PyOperation
734 //------------------------------------------------------------------------------
735 
/// Stores the context reference in the BaseContextObject base and keeps the
/// raw operation handle. Registration in the context's live-operation map is
/// not done here; PyOperation::createInstance performs it.
PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
    : BaseContextObject(std::move(contextRef)), operation(operation) {}
738 
739 PyOperation::~PyOperation() {
740   auto &liveOperations = getContext()->liveOperations;
741   assert(liveOperations.count(operation.ptr) == 1 &&
742          "destroying operation not in live map");
743   liveOperations.erase(operation.ptr);
744   if (!isAttached()) {
745     mlirOperationDestroy(operation);
746   }
747 }
748 
749 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
750                                            MlirOperation operation,
751                                            py::object parentKeepAlive) {
752   auto &liveOperations = contextRef->liveOperations;
753   // Create.
754   PyOperation *unownedOperation =
755       new PyOperation(std::move(contextRef), operation);
756   // Note that the default return value policy on cast is automatic_reference,
757   // which does not take ownership (delete will not be called).
758   // Just be explicit.
759   py::object pyRef =
760       py::cast(unownedOperation, py::return_value_policy::take_ownership);
761   unownedOperation->handle = pyRef;
762   if (parentKeepAlive) {
763     unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
764   }
765   liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
766   return PyOperationRef(unownedOperation, std::move(pyRef));
767 }
768 
769 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
770                                          MlirOperation operation,
771                                          py::object parentKeepAlive) {
772   auto &liveOperations = contextRef->liveOperations;
773   auto it = liveOperations.find(operation.ptr);
774   if (it == liveOperations.end()) {
775     // Create.
776     return createInstance(std::move(contextRef), operation,
777                           std::move(parentKeepAlive));
778   }
779   // Use existing.
780   PyOperation *existing = it->second.second;
781   py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
782   return PyOperationRef(existing, std::move(pyRef));
783 }
784 
785 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
786                                            MlirOperation operation,
787                                            py::object parentKeepAlive) {
788   auto &liveOperations = contextRef->liveOperations;
789   assert(liveOperations.count(operation.ptr) == 0 &&
790          "cannot create detached operation that already exists");
791   (void)liveOperations;
792 
793   PyOperationRef created = createInstance(std::move(contextRef), operation,
794                                           std::move(parentKeepAlive));
795   created->attached = false;
796   return created;
797 }
798 
799 void PyOperation::checkValid() const {
800   if (!valid) {
801     throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
802   }
803 }
804 
805 void PyOperationBase::print(py::object fileObject, bool binary,
806                             llvm::Optional<int64_t> largeElementsLimit,
807                             bool enableDebugInfo, bool prettyDebugInfo,
808                             bool printGenericOpForm, bool useLocalScope) {
809   PyOperation &operation = getOperation();
810   operation.checkValid();
811   if (fileObject.is_none())
812     fileObject = py::module::import("sys").attr("stdout");
813 
814   if (!printGenericOpForm && !mlirOperationVerify(operation)) {
815     fileObject.attr("write")("// Verification failed, printing generic form\n");
816     printGenericOpForm = true;
817   }
818 
819   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
820   if (largeElementsLimit)
821     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
822   if (enableDebugInfo)
823     mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
824   if (printGenericOpForm)
825     mlirOpPrintingFlagsPrintGenericOpForm(flags);
826 
827   PyFileAccumulator accum(fileObject, binary);
828   py::gil_scoped_release();
829   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
830                               accum.getUserData());
831   mlirOpPrintingFlagsDestroy(flags);
832 }
833 
834 py::object PyOperationBase::getAsm(bool binary,
835                                    llvm::Optional<int64_t> largeElementsLimit,
836                                    bool enableDebugInfo, bool prettyDebugInfo,
837                                    bool printGenericOpForm,
838                                    bool useLocalScope) {
839   py::object fileObject;
840   if (binary) {
841     fileObject = py::module::import("io").attr("BytesIO")();
842   } else {
843     fileObject = py::module::import("io").attr("StringIO")();
844   }
845   print(fileObject, /*binary=*/binary,
846         /*largeElementsLimit=*/largeElementsLimit,
847         /*enableDebugInfo=*/enableDebugInfo,
848         /*prettyDebugInfo=*/prettyDebugInfo,
849         /*printGenericOpForm=*/printGenericOpForm,
850         /*useLocalScope=*/useLocalScope);
851 
852   return fileObject.attr("getvalue")();
853 }
854 
855 PyOperationRef PyOperation::getParentOperation() {
856   if (!isAttached())
857     throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
858   MlirOperation operation = mlirOperationGetParentOperation(get());
859   if (mlirOperationIsNull(operation))
860     throw SetPyError(PyExc_ValueError, "Operation has no parent.");
861   return PyOperation::forOperation(getContext(), operation);
862 }
863 
864 PyBlock PyOperation::getBlock() {
865   PyOperationRef parentOperation = getParentOperation();
866   MlirBlock block = mlirOperationGetBlock(get());
867   assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
868   return PyBlock{std::move(parentOperation), block};
869 }
870 
871 py::object PyOperation::create(
872     std::string name, llvm::Optional<std::vector<PyType *>> results,
873     llvm::Optional<std::vector<PyValue *>> operands,
874     llvm::Optional<py::dict> attributes,
875     llvm::Optional<std::vector<PyBlock *>> successors, int regions,
876     DefaultingPyLocation location, py::object maybeIp) {
877   llvm::SmallVector<MlirValue, 4> mlirOperands;
878   llvm::SmallVector<MlirType, 4> mlirResults;
879   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
880   llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
881 
882   // General parameter validation.
883   if (regions < 0)
884     throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
885 
886   // Unpack/validate operands.
887   if (operands) {
888     mlirOperands.reserve(operands->size());
889     for (PyValue *operand : *operands) {
890       if (!operand)
891         throw SetPyError(PyExc_ValueError, "operand value cannot be None");
892       mlirOperands.push_back(operand->get());
893     }
894   }
895 
896   // Unpack/validate results.
897   if (results) {
898     mlirResults.reserve(results->size());
899     for (PyType *result : *results) {
900       // TODO: Verify result type originate from the same context.
901       if (!result)
902         throw SetPyError(PyExc_ValueError, "result type cannot be None");
903       mlirResults.push_back(*result);
904     }
905   }
906   // Unpack/validate attributes.
907   if (attributes) {
908     mlirAttributes.reserve(attributes->size());
909     for (auto &it : *attributes) {
910       std::string key;
911       try {
912         key = it.first.cast<std::string>();
913       } catch (py::cast_error &err) {
914         std::string msg = "Invalid attribute key (not a string) when "
915                           "attempting to create the operation \"" +
916                           name + "\" (" + err.what() + ")";
917         throw py::cast_error(msg);
918       }
919       try {
920         auto &attribute = it.second.cast<PyAttribute &>();
921         // TODO: Verify attribute originates from the same context.
922         mlirAttributes.emplace_back(std::move(key), attribute);
923       } catch (py::reference_cast_error &) {
924         // This exception seems thrown when the value is "None".
925         std::string msg =
926             "Found an invalid (`None`?) attribute value for the key \"" + key +
927             "\" when attempting to create the operation \"" + name + "\"";
928         throw py::cast_error(msg);
929       } catch (py::cast_error &err) {
930         std::string msg = "Invalid attribute value for the key \"" + key +
931                           "\" when attempting to create the operation \"" +
932                           name + "\" (" + err.what() + ")";
933         throw py::cast_error(msg);
934       }
935     }
936   }
937   // Unpack/validate successors.
938   if (successors) {
939     llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
940     mlirSuccessors.reserve(successors->size());
941     for (auto *successor : *successors) {
942       // TODO: Verify successor originate from the same context.
943       if (!successor)
944         throw SetPyError(PyExc_ValueError, "successor block cannot be None");
945       mlirSuccessors.push_back(successor->get());
946     }
947   }
948 
949   // Apply unpacked/validated to the operation state. Beyond this
950   // point, exceptions cannot be thrown or else the state will leak.
951   MlirOperationState state =
952       mlirOperationStateGet(toMlirStringRef(name), location);
953   if (!mlirOperands.empty())
954     mlirOperationStateAddOperands(&state, mlirOperands.size(),
955                                   mlirOperands.data());
956   if (!mlirResults.empty())
957     mlirOperationStateAddResults(&state, mlirResults.size(),
958                                  mlirResults.data());
959   if (!mlirAttributes.empty()) {
960     // Note that the attribute names directly reference bytes in
961     // mlirAttributes, so that vector must not be changed from here
962     // on.
963     llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
964     mlirNamedAttributes.reserve(mlirAttributes.size());
965     for (auto &it : mlirAttributes)
966       mlirNamedAttributes.push_back(mlirNamedAttributeGet(
967           mlirIdentifierGet(mlirAttributeGetContext(it.second),
968                             toMlirStringRef(it.first)),
969           it.second));
970     mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
971                                     mlirNamedAttributes.data());
972   }
973   if (!mlirSuccessors.empty())
974     mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
975                                     mlirSuccessors.data());
976   if (regions) {
977     llvm::SmallVector<MlirRegion, 4> mlirRegions;
978     mlirRegions.resize(regions);
979     for (int i = 0; i < regions; ++i)
980       mlirRegions[i] = mlirRegionCreate();
981     mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
982                                       mlirRegions.data());
983   }
984 
985   // Construct the operation.
986   MlirOperation operation = mlirOperationCreate(&state);
987   PyOperationRef created =
988       PyOperation::createDetached(location->getContext(), operation);
989 
990   // InsertPoint active?
991   if (!maybeIp.is(py::cast(false))) {
992     PyInsertionPoint *ip;
993     if (maybeIp.is_none()) {
994       ip = PyThreadContextEntry::getDefaultInsertionPoint();
995     } else {
996       ip = py::cast<PyInsertionPoint *>(maybeIp);
997     }
998     if (ip)
999       ip->insert(*created.get());
1000   }
1001 
1002   return created->createOpView();
1003 }
1004 
1005 py::object PyOperation::createOpView() {
1006   MlirIdentifier ident = mlirOperationGetName(get());
1007   MlirStringRef identStr = mlirIdentifierStr(ident);
1008   auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1009       StringRef(identStr.data, identStr.length));
1010   if (opViewClass)
1011     return (*opViewClass)(getRef().getObject());
1012   return py::cast(PyOpView(getRef().getObject()));
1013 }
1014 
1015 //------------------------------------------------------------------------------
1016 // PyOpView
1017 //------------------------------------------------------------------------------
1018 
1019 py::object
1020 PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1021                        py::list operandList,
1022                        llvm::Optional<py::dict> attributes,
1023                        llvm::Optional<std::vector<PyBlock *>> successors,
1024                        llvm::Optional<int> regions,
1025                        DefaultingPyLocation location, py::object maybeIp) {
1026   PyMlirContextRef context = location->getContext();
1027   // Class level operation construction metadata.
1028   std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
1029   // Operand and result segment specs are either none, which does no
1030   // variadic unpacking, or a list of ints with segment sizes, where each
1031   // element is either a positive number (typically 1 for a scalar) or -1 to
1032   // indicate that it is derived from the length of the same-indexed operand
1033   // or result (implying that it is a list at that position).
1034   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
1035   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
1036 
1037   std::vector<uint64_t> operandSegmentLengths;
1038   std::vector<uint64_t> resultSegmentLengths;
1039 
1040   // Validate/determine region count.
1041   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
1042   int opMinRegionCount = std::get<0>(opRegionSpec);
1043   bool opHasNoVariadicRegions = std::get<1>(opRegionSpec);
1044   if (!regions) {
1045     regions = opMinRegionCount;
1046   }
1047   if (*regions < opMinRegionCount) {
1048     throw py::value_error(
1049         (llvm::Twine("Operation \"") + name + "\" requires a minimum of " +
1050          llvm::Twine(opMinRegionCount) +
1051          " regions but was built with regions=" + llvm::Twine(*regions))
1052             .str());
1053   }
1054   if (opHasNoVariadicRegions && *regions > opMinRegionCount) {
1055     throw py::value_error(
1056         (llvm::Twine("Operation \"") + name + "\" requires a maximum of " +
1057          llvm::Twine(opMinRegionCount) +
1058          " regions but was built with regions=" + llvm::Twine(*regions))
1059             .str());
1060   }
1061 
1062   // Unpack results.
1063   std::vector<PyType *> resultTypes;
1064   resultTypes.reserve(resultTypeList.size());
1065   if (resultSegmentSpecObj.is_none()) {
1066     // Non-variadic result unpacking.
1067     for (auto it : llvm::enumerate(resultTypeList)) {
1068       try {
1069         resultTypes.push_back(py::cast<PyType *>(it.value()));
1070         if (!resultTypes.back())
1071           throw py::cast_error();
1072       } catch (py::cast_error &err) {
1073         throw py::value_error((llvm::Twine("Result ") +
1074                                llvm::Twine(it.index()) + " of operation \"" +
1075                                name + "\" must be a Type (" + err.what() + ")")
1076                                   .str());
1077       }
1078     }
1079   } else {
1080     // Sized result unpacking.
1081     auto resultSegmentSpec = py::cast<std::vector<int>>(resultSegmentSpecObj);
1082     if (resultSegmentSpec.size() != resultTypeList.size()) {
1083       throw py::value_error((llvm::Twine("Operation \"") + name +
1084                              "\" requires " +
1085                              llvm::Twine(resultSegmentSpec.size()) +
1086                              "result segments but was provided " +
1087                              llvm::Twine(resultTypeList.size()))
1088                                 .str());
1089     }
1090     resultSegmentLengths.reserve(resultTypeList.size());
1091     for (auto it :
1092          llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) {
1093       int segmentSpec = std::get<1>(it.value());
1094       if (segmentSpec == 1 || segmentSpec == 0) {
1095         // Unpack unary element.
1096         try {
1097           auto resultType = py::cast<PyType *>(std::get<0>(it.value()));
1098           if (resultType) {
1099             resultTypes.push_back(resultType);
1100             resultSegmentLengths.push_back(1);
1101           } else if (segmentSpec == 0) {
1102             // Allowed to be optional.
1103             resultSegmentLengths.push_back(0);
1104           } else {
1105             throw py::cast_error("was None and result is not optional");
1106           }
1107         } catch (py::cast_error &err) {
1108           throw py::value_error((llvm::Twine("Result ") +
1109                                  llvm::Twine(it.index()) + " of operation \"" +
1110                                  name + "\" must be a Type (" + err.what() +
1111                                  ")")
1112                                     .str());
1113         }
1114       } else if (segmentSpec == -1) {
1115         // Unpack sequence by appending.
1116         try {
1117           if (std::get<0>(it.value()).is_none()) {
1118             // Treat it as an empty list.
1119             resultSegmentLengths.push_back(0);
1120           } else {
1121             // Unpack the list.
1122             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1123             for (py::object segmentItem : segment) {
1124               resultTypes.push_back(py::cast<PyType *>(segmentItem));
1125               if (!resultTypes.back()) {
1126                 throw py::cast_error("contained a None item");
1127               }
1128             }
1129             resultSegmentLengths.push_back(segment.size());
1130           }
1131         } catch (std::exception &err) {
1132           // NOTE: Sloppy to be using a catch-all here, but there are at least
1133           // three different unrelated exceptions that can be thrown in the
1134           // above "casts". Just keep the scope above small and catch them all.
1135           throw py::value_error((llvm::Twine("Result ") +
1136                                  llvm::Twine(it.index()) + " of operation \"" +
1137                                  name + "\" must be a Sequence of Types (" +
1138                                  err.what() + ")")
1139                                     .str());
1140         }
1141       } else {
1142         throw py::value_error("Unexpected segment spec");
1143       }
1144     }
1145   }
1146 
1147   // Unpack operands.
1148   std::vector<PyValue *> operands;
1149   operands.reserve(operands.size());
1150   if (operandSegmentSpecObj.is_none()) {
1151     // Non-sized operand unpacking.
1152     for (auto it : llvm::enumerate(operandList)) {
1153       try {
1154         operands.push_back(py::cast<PyValue *>(it.value()));
1155         if (!operands.back())
1156           throw py::cast_error();
1157       } catch (py::cast_error &err) {
1158         throw py::value_error((llvm::Twine("Operand ") +
1159                                llvm::Twine(it.index()) + " of operation \"" +
1160                                name + "\" must be a Value (" + err.what() + ")")
1161                                   .str());
1162       }
1163     }
1164   } else {
1165     // Sized operand unpacking.
1166     auto operandSegmentSpec = py::cast<std::vector<int>>(operandSegmentSpecObj);
1167     if (operandSegmentSpec.size() != operandList.size()) {
1168       throw py::value_error((llvm::Twine("Operation \"") + name +
1169                              "\" requires " +
1170                              llvm::Twine(operandSegmentSpec.size()) +
1171                              "operand segments but was provided " +
1172                              llvm::Twine(operandList.size()))
1173                                 .str());
1174     }
1175     operandSegmentLengths.reserve(operandList.size());
1176     for (auto it :
1177          llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) {
1178       int segmentSpec = std::get<1>(it.value());
1179       if (segmentSpec == 1 || segmentSpec == 0) {
1180         // Unpack unary element.
1181         try {
1182           auto operandValue = py::cast<PyValue *>(std::get<0>(it.value()));
1183           if (operandValue) {
1184             operands.push_back(operandValue);
1185             operandSegmentLengths.push_back(1);
1186           } else if (segmentSpec == 0) {
1187             // Allowed to be optional.
1188             operandSegmentLengths.push_back(0);
1189           } else {
1190             throw py::cast_error("was None and operand is not optional");
1191           }
1192         } catch (py::cast_error &err) {
1193           throw py::value_error((llvm::Twine("Operand ") +
1194                                  llvm::Twine(it.index()) + " of operation \"" +
1195                                  name + "\" must be a Value (" + err.what() +
1196                                  ")")
1197                                     .str());
1198         }
1199       } else if (segmentSpec == -1) {
1200         // Unpack sequence by appending.
1201         try {
1202           if (std::get<0>(it.value()).is_none()) {
1203             // Treat it as an empty list.
1204             operandSegmentLengths.push_back(0);
1205           } else {
1206             // Unpack the list.
1207             auto segment = py::cast<py::sequence>(std::get<0>(it.value()));
1208             for (py::object segmentItem : segment) {
1209               operands.push_back(py::cast<PyValue *>(segmentItem));
1210               if (!operands.back()) {
1211                 throw py::cast_error("contained a None item");
1212               }
1213             }
1214             operandSegmentLengths.push_back(segment.size());
1215           }
1216         } catch (std::exception &err) {
1217           // NOTE: Sloppy to be using a catch-all here, but there are at least
1218           // three different unrelated exceptions that can be thrown in the
1219           // above "casts". Just keep the scope above small and catch them all.
1220           throw py::value_error((llvm::Twine("Operand ") +
1221                                  llvm::Twine(it.index()) + " of operation \"" +
1222                                  name + "\" must be a Sequence of Values (" +
1223                                  err.what() + ")")
1224                                     .str());
1225         }
1226       } else {
1227         throw py::value_error("Unexpected segment spec");
1228       }
1229     }
1230   }
1231 
1232   // Merge operand/result segment lengths into attributes if needed.
1233   if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) {
1234     // Dup.
1235     if (attributes) {
1236       attributes = py::dict(*attributes);
1237     } else {
1238       attributes = py::dict();
1239     }
1240     if (attributes->contains("result_segment_sizes") ||
1241         attributes->contains("operand_segment_sizes")) {
1242       throw py::value_error("Manually setting a 'result_segment_sizes' or "
1243                             "'operand_segment_sizes' attribute is unsupported. "
1244                             "Use Operation.create for such low-level access.");
1245     }
1246 
1247     // Add result_segment_sizes attribute.
1248     if (!resultSegmentLengths.empty()) {
1249       int64_t size = resultSegmentLengths.size();
1250       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
1251           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
1252           resultSegmentLengths.size(), resultSegmentLengths.data());
1253       (*attributes)["result_segment_sizes"] =
1254           PyAttribute(context, segmentLengthAttr);
1255     }
1256 
1257     // Add operand_segment_sizes attribute.
1258     if (!operandSegmentLengths.empty()) {
1259       int64_t size = operandSegmentLengths.size();
1260       MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
1261           mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
1262           operandSegmentLengths.size(), operandSegmentLengths.data());
1263       (*attributes)["operand_segment_sizes"] =
1264           PyAttribute(context, segmentLengthAttr);
1265     }
1266   }
1267 
1268   // Delegate to create.
1269   return PyOperation::create(std::move(name),
1270                              /*results=*/std::move(resultTypes),
1271                              /*operands=*/std::move(operands),
1272                              /*attributes=*/std::move(attributes),
1273                              /*successors=*/std::move(successors),
1274                              /*regions=*/*regions, location, maybeIp);
1275 }
1276 
/// Constructs an op view from a Python object that wraps an operation (or
/// any other PyOperationBase subclass).
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      // Note: `operation` here is the member initialized just above, so the
      // stored object is the one obtained from the resolved operation's own
      // reference, not the constructor argument itself.
      operationObject(operation.getRef().getObject()) {}
1282 
1283 py::object PyOpView::createRawSubclass(py::object userClass) {
1284   // This is... a little gross. The typical pattern is to have a pure python
1285   // class that extends OpView like:
1286   //   class AddFOp(_cext.ir.OpView):
1287   //     def __init__(self, loc, lhs, rhs):
1288   //       operation = loc.context.create_operation(
1289   //           "addf", lhs, rhs, results=[lhs.type])
1290   //       super().__init__(operation)
1291   //
1292   // I.e. The goal of the user facing type is to provide a nice constructor
1293   // that has complete freedom for the op under construction. This is at odds
1294   // with our other desire to sometimes create this object by just passing an
1295   // operation (to initialize the base class). We could do *arg and **kwargs
1296   // munging to try to make it work, but instead, we synthesize a new class
1297   // on the fly which extends this user class (AddFOp in this example) and
1298   // *give it* the base class's __init__ method, thus bypassing the
1299   // intermediate subclass's __init__ method entirely. While slightly,
1300   // underhanded, this is safe/legal because the type hierarchy has not changed
1301   // (we just added a new leaf) and we aren't mucking around with __new__.
1302   // Typically, this new class will be stored on the original as "_Raw" and will
1303   // be used for casts and other things that need a variant of the class that
1304   // is initialized purely from an operation.
1305   py::object parentMetaclass =
1306       py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1307   py::dict attributes;
1308   // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1309   // now.
1310   //   auto opViewType = py::type::of<PyOpView>();
1311   auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1312   attributes["__init__"] = opViewType.attr("__init__");
1313   py::str origName = userClass.attr("__name__");
1314   py::str newName = py::str("_") + origName;
1315   return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1316 }
1317 
1318 //------------------------------------------------------------------------------
1319 // PyInsertionPoint.
1320 //------------------------------------------------------------------------------
1321 
/// Creates an insertion point for `block` with no reference operation; see
/// `insert` for how a missing reference operation is handled.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}

/// Creates an insertion point that inserts before the given operation. The
/// block is taken from that operation, which raises if the operation has no
/// parent (see PyOperation::getBlock).
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      // Uses the `refOperation` member initialized just above.
      block((*refOperation)->getBlock()) {}
1327 
1328 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1329   PyOperation &operation = operationBase.getOperation();
1330   if (operation.isAttached())
1331     throw SetPyError(PyExc_ValueError,
1332                      "Attempt to insert operation that is already attached");
1333   block.getParentOperation()->checkValid();
1334   MlirOperation beforeOp = {nullptr};
1335   if (refOperation) {
1336     // Insert before operation.
1337     (*refOperation)->checkValid();
1338     beforeOp = (*refOperation)->get();
1339   } else {
1340     // Insert at end (before null) is only valid if the block does not
1341     // already end in a known terminator (violating this will cause assertion
1342     // failures later).
1343     if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1344       throw py::index_error("Cannot insert operation at the end of a block "
1345                             "that already has a terminator. Did you mean to "
1346                             "use 'InsertionPoint.at_block_terminator(block)' "
1347                             "versus 'InsertionPoint(block)'?");
1348     }
1349   }
1350   mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1351   operation.setAttached();
1352 }
1353 
1354 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1355   MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1356   if (mlirOperationIsNull(firstOp)) {
1357     // Just insert at end.
1358     return PyInsertionPoint(block);
1359   }
1360 
1361   // Insert before first op.
1362   PyOperationRef firstOpRef = PyOperation::forOperation(
1363       block.getParentOperation()->getContext(), firstOp);
1364   return PyInsertionPoint{block, std::move(firstOpRef)};
1365 }
1366 
1367 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1368   MlirOperation terminator = mlirBlockGetTerminator(block.get());
1369   if (mlirOperationIsNull(terminator))
1370     throw SetPyError(PyExc_ValueError, "Block has no terminator");
1371   PyOperationRef terminatorOpRef = PyOperation::forOperation(
1372       block.getParentOperation()->getContext(), terminator);
1373   return PyInsertionPoint{block, std::move(terminatorOpRef)};
1374 }
1375 
1376 py::object PyInsertionPoint::contextEnter() {
1377   return PyThreadContextEntry::pushInsertionPoint(*this);
1378 }
1379 
1380 void PyInsertionPoint::contextExit(pybind11::object excType,
1381                                    pybind11::object excVal,
1382                                    pybind11::object excTb) {
1383   PyThreadContextEntry::popInsertionPoint(*this);
1384 }
1385 
1386 //------------------------------------------------------------------------------
1387 // PyAttribute.
1388 //------------------------------------------------------------------------------
1389 
1390 bool PyAttribute::operator==(const PyAttribute &other) {
1391   return mlirAttributeEqual(attr, other.attr);
1392 }
1393 
1394 py::object PyAttribute::getCapsule() {
1395   return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1396 }
1397 
1398 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1399   MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1400   if (mlirAttributeIsNull(rawAttr))
1401     throw py::error_already_set();
1402   return PyAttribute(
1403       PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1404 }
1405 
1406 //------------------------------------------------------------------------------
1407 // PyNamedAttribute.
1408 //------------------------------------------------------------------------------
1409 
1410 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1411     : ownedName(new std::string(std::move(ownedName))) {
1412   namedAttr = mlirNamedAttributeGet(
1413       mlirIdentifierGet(mlirAttributeGetContext(attr),
1414                         toMlirStringRef(*this->ownedName)),
1415       attr);
1416 }
1417 
1418 //------------------------------------------------------------------------------
1419 // PyType.
1420 //------------------------------------------------------------------------------
1421 
1422 bool PyType::operator==(const PyType &other) {
1423   return mlirTypeEqual(type, other.type);
1424 }
1425 
1426 py::object PyType::getCapsule() {
1427   return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1428 }
1429 
1430 PyType PyType::createFromCapsule(py::object capsule) {
1431   MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1432   if (mlirTypeIsNull(rawType))
1433     throw py::error_already_set();
1434   return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1435                 rawType);
1436 }
1437 
1438 //------------------------------------------------------------------------------
// PyValue and subclasses.
1440 //------------------------------------------------------------------------------
1441 
1442 namespace {
1443 /// CRTP base class for Python MLIR values that subclass Value and should be
1444 /// castable from it. The value hierarchy is one level deep and is not supposed
1445 /// to accommodate other levels unless core MLIR changes.
1446 template <typename DerivedTy>
1447 class PyConcreteValue : public PyValue {
1448 public:
1449   // Derived classes must define statics for:
1450   //   IsAFunctionTy isaFunction
1451   //   const char *pyClassName
1452   // and redefine bindDerived.
1453   using ClassTy = py::class_<DerivedTy, PyValue>;
1454   using IsAFunctionTy = bool (*)(MlirValue);
1455 
1456   PyConcreteValue() = default;
1457   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1458       : PyValue(operationRef, value) {}
1459   PyConcreteValue(PyValue &orig)
1460       : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1461 
1462   /// Attempts to cast the original value to the derived type and throws on
1463   /// type mismatches.
1464   static MlirValue castFrom(PyValue &orig) {
1465     if (!DerivedTy::isaFunction(orig.get())) {
1466       auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1467       throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1468                                              DerivedTy::pyClassName +
1469                                              " (from " + origRepr + ")");
1470     }
1471     return orig.get();
1472   }
1473 
1474   /// Binds the Python module objects to functions of this class.
1475   static void bind(py::module &m) {
1476     auto cls = ClassTy(m, DerivedTy::pyClassName);
1477     cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1478     DerivedTy::bindDerived(cls);
1479   }
1480 
1481   /// Implemented by derived classes to add methods to the Python subclass.
1482   static void bindDerived(ClassTy &m) {}
1483 };
1484 
1485 /// Python wrapper for MlirBlockArgument.
1486 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1487 public:
1488   static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1489   static constexpr const char *pyClassName = "BlockArgument";
1490   using PyConcreteValue::PyConcreteValue;
1491 
1492   static void bindDerived(ClassTy &c) {
1493     c.def_property_readonly("owner", [](PyBlockArgument &self) {
1494       return PyBlock(self.getParentOperation(),
1495                      mlirBlockArgumentGetOwner(self.get()));
1496     });
1497     c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1498       return mlirBlockArgumentGetArgNumber(self.get());
1499     });
1500     c.def("set_type", [](PyBlockArgument &self, PyType type) {
1501       return mlirBlockArgumentSetType(self.get(), type);
1502     });
1503   }
1504 };
1505 
1506 /// Python wrapper for MlirOpResult.
1507 class PyOpResult : public PyConcreteValue<PyOpResult> {
1508 public:
1509   static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1510   static constexpr const char *pyClassName = "OpResult";
1511   using PyConcreteValue::PyConcreteValue;
1512 
1513   static void bindDerived(ClassTy &c) {
1514     c.def_property_readonly("owner", [](PyOpResult &self) {
1515       assert(
1516           mlirOperationEqual(self.getParentOperation()->get(),
1517                              mlirOpResultGetOwner(self.get())) &&
1518           "expected the owner of the value in Python to match that in the IR");
1519       return self.getParentOperation();
1520     });
1521     c.def_property_readonly("result_number", [](PyOpResult &self) {
1522       return mlirOpResultGetResultNumber(self.get());
1523     });
1524   }
1525 };
1526 
1527 /// A list of block arguments. Internally, these are stored as consecutive
1528 /// elements, random access is cheap. The argument list is associated with the
1529 /// operation that contains the block (detached blocks are not allowed in
1530 /// Python bindings) and extends its lifetime.
1531 class PyBlockArgumentList {
1532 public:
1533   PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
1534       : operation(std::move(operation)), block(block) {}
1535 
1536   /// Returns the length of the block argument list.
1537   intptr_t dunderLen() {
1538     operation->checkValid();
1539     return mlirBlockGetNumArguments(block);
1540   }
1541 
1542   /// Returns `index`-th element of the block argument list.
1543   PyBlockArgument dunderGetItem(intptr_t index) {
1544     if (index < 0 || index >= dunderLen()) {
1545       throw SetPyError(PyExc_IndexError,
1546                        "attempt to access out of bounds region");
1547     }
1548     PyValue value(operation, mlirBlockGetArgument(block, index));
1549     return PyBlockArgument(value);
1550   }
1551 
1552   /// Defines a Python class in the bindings.
1553   static void bind(py::module &m) {
1554     py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
1555         .def("__len__", &PyBlockArgumentList::dunderLen)
1556         .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
1557   }
1558 
1559 private:
1560   PyOperationRef operation;
1561   MlirBlock block;
1562 };
1563 
1564 /// A list of operation operands. Internally, these are stored as consecutive
1565 /// elements, random access is cheap. The result list is associated with the
1566 /// operation whose results these are, and extends the lifetime of this
1567 /// operation.
1568 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1569 public:
1570   static constexpr const char *pyClassName = "OpOperandList";
1571 
1572   PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1573                   intptr_t length = -1, intptr_t step = 1)
1574       : Sliceable(startIndex,
1575                   length == -1 ? mlirOperationGetNumOperands(operation->get())
1576                                : length,
1577                   step),
1578         operation(operation) {}
1579 
1580   intptr_t getNumElements() {
1581     operation->checkValid();
1582     return mlirOperationGetNumOperands(operation->get());
1583   }
1584 
1585   PyValue getElement(intptr_t pos) {
1586     return PyValue(operation, mlirOperationGetOperand(operation->get(), pos));
1587   }
1588 
1589   PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1590     return PyOpOperandList(operation, startIndex, length, step);
1591   }
1592 
1593 private:
1594   PyOperationRef operation;
1595 };
1596 
1597 /// A list of operation results. Internally, these are stored as consecutive
1598 /// elements, random access is cheap. The result list is associated with the
1599 /// operation whose results these are, and extends the lifetime of this
1600 /// operation.
1601 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1602 public:
1603   static constexpr const char *pyClassName = "OpResultList";
1604 
1605   PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1606                  intptr_t length = -1, intptr_t step = 1)
1607       : Sliceable(startIndex,
1608                   length == -1 ? mlirOperationGetNumResults(operation->get())
1609                                : length,
1610                   step),
1611         operation(operation) {}
1612 
1613   intptr_t getNumElements() {
1614     operation->checkValid();
1615     return mlirOperationGetNumResults(operation->get());
1616   }
1617 
1618   PyOpResult getElement(intptr_t index) {
1619     PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1620     return PyOpResult(value);
1621   }
1622 
1623   PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1624     return PyOpResultList(operation, startIndex, length, step);
1625   }
1626 
1627 private:
1628   PyOperationRef operation;
1629 };
1630 
1631 /// A list of operation attributes. Can be indexed by name, producing
1632 /// attributes, or by index, producing named attributes.
1633 class PyOpAttributeMap {
1634 public:
1635   PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1636 
1637   PyAttribute dunderGetItemNamed(const std::string &name) {
1638     MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1639                                                          toMlirStringRef(name));
1640     if (mlirAttributeIsNull(attr)) {
1641       throw SetPyError(PyExc_KeyError,
1642                        "attempt to access a non-existent attribute");
1643     }
1644     return PyAttribute(operation->getContext(), attr);
1645   }
1646 
1647   PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1648     if (index < 0 || index >= dunderLen()) {
1649       throw SetPyError(PyExc_IndexError,
1650                        "attempt to access out of bounds attribute");
1651     }
1652     MlirNamedAttribute namedAttr =
1653         mlirOperationGetAttribute(operation->get(), index);
1654     return PyNamedAttribute(
1655         namedAttr.attribute,
1656         std::string(mlirIdentifierStr(namedAttr.name).data));
1657   }
1658 
1659   void dunderSetItem(const std::string &name, PyAttribute attr) {
1660     mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1661                                     attr);
1662   }
1663 
1664   void dunderDelItem(const std::string &name) {
1665     int removed = mlirOperationRemoveAttributeByName(operation->get(),
1666                                                      toMlirStringRef(name));
1667     if (!removed)
1668       throw SetPyError(PyExc_KeyError,
1669                        "attempt to delete a non-existent attribute");
1670   }
1671 
1672   intptr_t dunderLen() {
1673     return mlirOperationGetNumAttributes(operation->get());
1674   }
1675 
1676   bool dunderContains(const std::string &name) {
1677     return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1678         operation->get(), toMlirStringRef(name)));
1679   }
1680 
1681   static void bind(py::module &m) {
1682     py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
1683         .def("__contains__", &PyOpAttributeMap::dunderContains)
1684         .def("__len__", &PyOpAttributeMap::dunderLen)
1685         .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1686         .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1687         .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1688         .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1689   }
1690 
1691 private:
1692   PyOperationRef operation;
1693 };
1694 
1695 } // end namespace
1696 
1697 //------------------------------------------------------------------------------
1698 // Populates the core exports of the 'ir' submodule.
1699 //------------------------------------------------------------------------------
1700 
1701 void mlir::python::populateIRCore(py::module &m) {
1702   //----------------------------------------------------------------------------
1703   // Mapping of MlirContext
1704   //----------------------------------------------------------------------------
1705   py::class_<PyMlirContext>(m, "Context")
1706       .def(py::init<>(&PyMlirContext::createNewContextForInit))
1707       .def_static("_get_live_count", &PyMlirContext::getLiveCount)
1708       .def("_get_context_again",
1709            [](PyMlirContext &self) {
1710              PyMlirContextRef ref = PyMlirContext::forContext(self.get());
1711              return ref.releaseObject();
1712            })
1713       .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
1714       .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
1715       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
1716                              &PyMlirContext::getCapsule)
1717       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
1718       .def("__enter__", &PyMlirContext::contextEnter)
1719       .def("__exit__", &PyMlirContext::contextExit)
1720       .def_property_readonly_static(
1721           "current",
1722           [](py::object & /*class*/) {
1723             auto *context = PyThreadContextEntry::getDefaultContext();
1724             if (!context)
1725               throw SetPyError(PyExc_ValueError, "No current Context");
1726             return context;
1727           },
1728           "Gets the Context bound to the current thread or raises ValueError")
1729       .def_property_readonly(
1730           "dialects",
1731           [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1732           "Gets a container for accessing dialects by name")
1733       .def_property_readonly(
1734           "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
1735           "Alias for 'dialect'")
1736       .def(
1737           "get_dialect_descriptor",
1738           [=](PyMlirContext &self, std::string &name) {
1739             MlirDialect dialect = mlirContextGetOrLoadDialect(
1740                 self.get(), {name.data(), name.size()});
1741             if (mlirDialectIsNull(dialect)) {
1742               throw SetPyError(PyExc_ValueError,
1743                                Twine("Dialect '") + name + "' not found");
1744             }
1745             return PyDialectDescriptor(self.getRef(), dialect);
1746           },
1747           "Gets or loads a dialect by name, returning its descriptor object")
1748       .def_property(
1749           "allow_unregistered_dialects",
1750           [](PyMlirContext &self) -> bool {
1751             return mlirContextGetAllowUnregisteredDialects(self.get());
1752           },
1753           [](PyMlirContext &self, bool value) {
1754             mlirContextSetAllowUnregisteredDialects(self.get(), value);
1755           });
1756 
1757   //----------------------------------------------------------------------------
1758   // Mapping of PyDialectDescriptor
1759   //----------------------------------------------------------------------------
1760   py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
1761       .def_property_readonly("namespace",
1762                              [](PyDialectDescriptor &self) {
1763                                MlirStringRef ns =
1764                                    mlirDialectGetNamespace(self.get());
1765                                return py::str(ns.data, ns.length);
1766                              })
1767       .def("__repr__", [](PyDialectDescriptor &self) {
1768         MlirStringRef ns = mlirDialectGetNamespace(self.get());
1769         std::string repr("<DialectDescriptor ");
1770         repr.append(ns.data, ns.length);
1771         repr.append(">");
1772         return repr;
1773       });
1774 
1775   //----------------------------------------------------------------------------
1776   // Mapping of PyDialects
1777   //----------------------------------------------------------------------------
1778   py::class_<PyDialects>(m, "Dialects")
1779       .def("__getitem__",
1780            [=](PyDialects &self, std::string keyName) {
1781              MlirDialect dialect =
1782                  self.getDialectForKey(keyName, /*attrError=*/false);
1783              py::object descriptor =
1784                  py::cast(PyDialectDescriptor{self.getContext(), dialect});
1785              return createCustomDialectWrapper(keyName, std::move(descriptor));
1786            })
1787       .def("__getattr__", [=](PyDialects &self, std::string attrName) {
1788         MlirDialect dialect =
1789             self.getDialectForKey(attrName, /*attrError=*/true);
1790         py::object descriptor =
1791             py::cast(PyDialectDescriptor{self.getContext(), dialect});
1792         return createCustomDialectWrapper(attrName, std::move(descriptor));
1793       });
1794 
1795   //----------------------------------------------------------------------------
1796   // Mapping of PyDialect
1797   //----------------------------------------------------------------------------
1798   py::class_<PyDialect>(m, "Dialect")
1799       .def(py::init<py::object>(), "descriptor")
1800       .def_property_readonly(
1801           "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
1802       .def("__repr__", [](py::object self) {
1803         auto clazz = self.attr("__class__");
1804         return py::str("<Dialect ") +
1805                self.attr("descriptor").attr("namespace") + py::str(" (class ") +
1806                clazz.attr("__module__") + py::str(".") +
1807                clazz.attr("__name__") + py::str(")>");
1808       });
1809 
  //----------------------------------------------------------------------------
  // Mapping of Location
  //----------------------------------------------------------------------------
  // Exposes PyLocation to Python as `Location`.
  py::class_<PyLocation>(m, "Location")
      // Capsule accessors used for interop through the MLIR C API.
      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
      // Python context-manager protocol (`with loc:`).
      .def("__enter__", &PyLocation::contextEnter)
      .def("__exit__", &PyLocation::contextExit)
      // Two `__eq__` overloads, registered in this order: Location-to-Location
      // comparison via mlirLocationEqual, then a catch-all that reports any
      // other object as unequal.
      .def("__eq__",
           [](PyLocation &self, PyLocation &other) -> bool {
             return mlirLocationEqual(self, other);
           })
      .def("__eq__", [](PyLocation &self, py::object other) { return false; })
      .def_property_readonly_static(
          "current",
          [](py::object & /*class*/) {
            auto *loc = PyThreadContextEntry::getDefaultLocation();
            if (!loc)
              throw SetPyError(PyExc_ValueError, "No current Location");
            return loc;
          },
          "Gets the Location bound to the current thread or raises ValueError")
      // Factory for an unknown location; `context` defaults to None and is
      // resolved through DefaultingPyMlirContext.
      .def_static(
          "unknown",
          [](DefaultingPyMlirContext context) {
            return PyLocation(context->getRef(),
                              mlirLocationUnknownGet(context->get()));
          },
          py::arg("context") = py::none(),
          "Gets a Location representing an unknown location")
      // Factory for a file/line/column location; `context` defaults to None
      // and is resolved through DefaultingPyMlirContext.
      .def_static(
          "file",
          [](std::string filename, int line, int col,
             DefaultingPyMlirContext context) {
            return PyLocation(
                context->getRef(),
                mlirLocationFileLineColGet(
                    context->get(), toMlirStringRef(filename), line, col));
          },
          py::arg("filename"), py::arg("line"), py::arg("col"),
          py::arg("context") = py::none(), kContextGetFileLocationDocstring)
      .def_property_readonly(
          "context",
          [](PyLocation &self) { return self.getContext().getObject(); },
          "Context that owns the Location")
      // `repr()` is the textual form produced by mlirLocationPrint, collected
      // through a PyPrintAccumulator.
      .def("__repr__", [](PyLocation &self) {
        PyPrintAccumulator printAccum;
        mlirLocationPrint(self, printAccum.getCallback(),
                          printAccum.getUserData());
        return printAccum.join();
      });
1861 
1862   //----------------------------------------------------------------------------
1863   // Mapping of Module
1864   //----------------------------------------------------------------------------
1865   py::class_<PyModule>(m, "Module")
1866       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
1867       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
1868       .def_static(
1869           "parse",
1870           [](const std::string moduleAsm, DefaultingPyMlirContext context) {
1871             MlirModule module = mlirModuleCreateParse(
1872                 context->get(), toMlirStringRef(moduleAsm));
1873             // TODO: Rework error reporting once diagnostic engine is exposed
1874             // in C API.
1875             if (mlirModuleIsNull(module)) {
1876               throw SetPyError(
1877                   PyExc_ValueError,
1878                   "Unable to parse module assembly (see diagnostics)");
1879             }
1880             return PyModule::forModule(module).releaseObject();
1881           },
1882           py::arg("asm"), py::arg("context") = py::none(),
1883           kModuleParseDocstring)
1884       .def_static(
1885           "create",
1886           [](DefaultingPyLocation loc) {
1887             MlirModule module = mlirModuleCreateEmpty(loc);
1888             return PyModule::forModule(module).releaseObject();
1889           },
1890           py::arg("loc") = py::none(), "Creates an empty module")
1891       .def_property_readonly(
1892           "context",
1893           [](PyModule &self) { return self.getContext().getObject(); },
1894           "Context that created the Module")
1895       .def_property_readonly(
1896           "operation",
1897           [](PyModule &self) {
1898             return PyOperation::forOperation(self.getContext(),
1899                                              mlirModuleGetOperation(self.get()),
1900                                              self.getRef().releaseObject())
1901                 .releaseObject();
1902           },
1903           "Accesses the module as an operation")
1904       .def_property_readonly(
1905           "body",
1906           [](PyModule &self) {
1907             PyOperationRef module_op = PyOperation::forOperation(
1908                 self.getContext(), mlirModuleGetOperation(self.get()),
1909                 self.getRef().releaseObject());
1910             PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
1911             return returnBlock;
1912           },
1913           "Return the block for this module")
1914       .def(
1915           "dump",
1916           [](PyModule &self) {
1917             mlirOperationDump(mlirModuleGetOperation(self.get()));
1918           },
1919           kDumpDocstring)
1920       .def(
1921           "__str__",
1922           [](PyModule &self) {
1923             MlirOperation operation = mlirModuleGetOperation(self.get());
1924             PyPrintAccumulator printAccum;
1925             mlirOperationPrint(operation, printAccum.getCallback(),
1926                                printAccum.getUserData());
1927             return printAccum.join();
1928           },
1929           kOperationStrDunderDocstring);
1930 
  //----------------------------------------------------------------------------
  // Mapping of Operation.
  //----------------------------------------------------------------------------
  // `_OperationBase` carries the members shared by the `Operation` and
  // `OpView` Python classes, both of which are registered with
  // PyOperationBase as their base.
  py::class_<PyOperationBase>(m, "_OperationBase")
      // Equality is identity of the underlying PyOperation object; the second
      // overload reports any other Python object as unequal.
      .def("__eq__",
           [](PyOperationBase &self, PyOperationBase &other) {
             return &self.getOperation() == &other.getOperation();
           })
      .def("__eq__",
           [](PyOperationBase &self, py::object other) { return false; })
      // Container views over the operation; each is constructed from a
      // reference to the operation.
      .def_property_readonly("attributes",
                             [](PyOperationBase &self) {
                               return PyOpAttributeMap(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("operands",
                             [](PyOperationBase &self) {
                               return PyOpOperandList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly("regions",
                             [](PyOperationBase &self) {
                               return PyRegionList(
                                   self.getOperation().getRef());
                             })
      .def_property_readonly(
          "results",
          [](PyOperationBase &self) {
            return PyOpResultList(self.getOperation().getRef());
          },
          "Returns the list of Operation results.")
      .def_property_readonly(
          "result",
          [](PyOperationBase &self) {
            auto &operation = self.getOperation();
            auto numResults = mlirOperationGetNumResults(operation);
            // Only meaningful for single-result operations; anything else is
            // reported as ValueError naming the operation and its result
            // count.
            if (numResults != 1) {
              auto name = mlirIdentifierStr(mlirOperationGetName(operation));
              throw SetPyError(
                  PyExc_ValueError,
                  Twine("Cannot call .result on operation ") +
                      StringRef(name.data, name.length) + " which has " +
                      Twine(numResults) +
                      " results (it is only valid for operations with a "
                      "single result)");
            }
            return PyOpResult(operation.getRef(),
                              mlirOperationGetResult(operation, 0));
          },
          "Shortcut to get an op result if it has only one (throws an error "
          "otherwise).")
      // Iterating an operation yields a PyRegionIterator built from it.
      .def("__iter__",
           [](PyOperationBase &self) {
             return PyRegionIterator(self.getOperation().getRef());
           })
      // `str()` is getAsm with every option switched off / left unset.
      .def(
          "__str__",
          [](PyOperationBase &self) {
            return self.getAsm(/*binary=*/false,
                               /*largeElementsLimit=*/llvm::None,
                               /*enableDebugInfo=*/false,
                               /*prettyDebugInfo=*/false,
                               /*printGenericOpForm=*/false,
                               /*useLocalScope=*/false);
          },
          "Returns the assembly form of the operation.")
      .def("print", &PyOperationBase::print,
           // Careful: Lots of arguments must match up with print method.
           py::arg("file") = py::none(), py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false, kOperationPrintDocstring)
      .def("get_asm", &PyOperationBase::getAsm,
           // Careful: Lots of arguments must match up with get_asm method.
           py::arg("binary") = false,
           py::arg("large_elements_limit") = py::none(),
           py::arg("enable_debug_info") = false,
           py::arg("pretty_debug_info") = false,
           py::arg("print_generic_op_form") = false,
           py::arg("use_local_scope") = false, kOperationGetAsmDocstring)
      .def(
          "verify",
          [](PyOperationBase &self) {
            return mlirOperationVerify(self.getOperation());
          },
          "Verify the operation and return true if it passes, false if it "
          "fails.");
2020 
  // `Operation`: the concrete operation class, derived from `_OperationBase`.
  py::class_<PyOperation, PyOperationBase>(m, "Operation")
      // Keyword names and defaults for `create`; every argument except `name`
      // is optional (`regions` defaults to 0, the others to None).
      // NOTE(review): the argument order must line up with
      // PyOperation::create, which is defined outside this section.
      .def_static("create", &PyOperation::create, py::arg("name"),
                  py::arg("results") = py::none(),
                  py::arg("operands") = py::none(),
                  py::arg("attributes") = py::none(),
                  py::arg("successors") = py::none(), py::arg("regions") = 0,
                  py::arg("loc") = py::none(), py::arg("ip") = py::none(),
                  kOperationCreateDocstring)
      // The operation name is copied into a Python str using the explicit
      // length of the identifier's string reference.
      .def_property_readonly("name",
                             [](PyOperation &self) {
                               MlirOperation operation = self.get();
                               MlirStringRef name = mlirIdentifierStr(
                                   mlirOperationGetName(operation));
                               return py::str(name.data, name.length);
                             })
      .def_property_readonly(
          "context",
          [](PyOperation &self) { return self.getContext().getObject(); },
          "Context that owns the Operation")
      .def_property_readonly("opview", &PyOperation::createOpView);
2041 
  // `OpView`: a Python-constructible class derived from `_OperationBase` and
  // built from a single Python object.
  auto opViewClass =
      py::class_<PyOpView, PyOperationBase>(m, "OpView")
          .def(py::init<py::object>())
          .def_property_readonly("operation", &PyOpView::getOperationObject)
          .def_property_readonly(
              "context",
              [](PyOpView &self) {
                return self.getOperation().getContext().getObject();
              },
              "Context that owns the Operation")
          // `str(view)` is the `str()` of the wrapped operation object.
          .def("__str__", [](PyOpView &self) {
            return py::str(self.getOperationObject());
          });
  // Class-level defaults set on the base `OpView` class.
  // NOTE(review): presumably subclasses override these and
  // PyOpView::buildGeneric reads them; the meaning of the (0, true) tuple is
  // not visible here — confirm against buildGeneric.
  opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
  opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
  opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
  // `build_generic` is installed as a classmethod; every argument after `cls`
  // defaults to None.
  opViewClass.attr("build_generic") = classmethod(
      &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
      py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
      py::arg("successors") = py::none(), py::arg("regions") = py::none(),
      py::arg("loc") = py::none(), py::arg("ip") = py::none(),
      "Builds a specific, generated OpView based on class level attributes.");
2064 
2065   //----------------------------------------------------------------------------
2066   // Mapping of PyRegion.
2067   //----------------------------------------------------------------------------
2068   py::class_<PyRegion>(m, "Region")
2069       .def_property_readonly(
2070           "blocks",
2071           [](PyRegion &self) {
2072             return PyBlockList(self.getParentOperation(), self.get());
2073           },
2074           "Returns a forward-optimized sequence of blocks.")
2075       .def(
2076           "__iter__",
2077           [](PyRegion &self) {
2078             self.checkValid();
2079             MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2080             return PyBlockIterator(self.getParentOperation(), firstBlock);
2081           },
2082           "Iterates over blocks in the region.")
2083       .def("__eq__",
2084            [](PyRegion &self, PyRegion &other) {
2085              return self.get().ptr == other.get().ptr;
2086            })
2087       .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2088 
2089   //----------------------------------------------------------------------------
2090   // Mapping of PyBlock.
2091   //----------------------------------------------------------------------------
2092   py::class_<PyBlock>(m, "Block")
2093       .def_property_readonly(
2094           "arguments",
2095           [](PyBlock &self) {
2096             return PyBlockArgumentList(self.getParentOperation(), self.get());
2097           },
2098           "Returns a list of block arguments.")
2099       .def_property_readonly(
2100           "operations",
2101           [](PyBlock &self) {
2102             return PyOperationList(self.getParentOperation(), self.get());
2103           },
2104           "Returns a forward-optimized sequence of operations.")
2105       .def(
2106           "__iter__",
2107           [](PyBlock &self) {
2108             self.checkValid();
2109             MlirOperation firstOperation =
2110                 mlirBlockGetFirstOperation(self.get());
2111             return PyOperationIterator(self.getParentOperation(),
2112                                        firstOperation);
2113           },
2114           "Iterates over operations in the block.")
2115       .def("__eq__",
2116            [](PyBlock &self, PyBlock &other) {
2117              return self.get().ptr == other.get().ptr;
2118            })
2119       .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2120       .def(
2121           "__str__",
2122           [](PyBlock &self) {
2123             self.checkValid();
2124             PyPrintAccumulator printAccum;
2125             mlirBlockPrint(self.get(), printAccum.getCallback(),
2126                            printAccum.getUserData());
2127             return printAccum.join();
2128           },
2129           "Returns the assembly form of the block.");
2130 
  //----------------------------------------------------------------------------
  // Mapping of PyInsertionPoint.
  //----------------------------------------------------------------------------

  // Exposes PyInsertionPoint to Python as `InsertionPoint`. Two constructors
  // are registered, in this order: one taking a block and one taking an
  // operation.
  py::class_<PyInsertionPoint>(m, "InsertionPoint")
      .def(py::init<PyBlock &>(), py::arg("block"),
           "Inserts after the last operation but still inside the block.")
      // Python context-manager protocol (`with ip:`).
      .def("__enter__", &PyInsertionPoint::contextEnter)
      .def("__exit__", &PyInsertionPoint::contextExit)
      .def_property_readonly_static(
          "current",
          [](py::object & /*class*/) {
            auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
            if (!ip)
              throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
            return ip;
          },
          "Gets the InsertionPoint bound to the current thread or raises "
          "ValueError if none has been set")
      .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
           "Inserts before a referenced operation.")
      // Named factories for the two block-relative positions.
      .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
                  py::arg("block"), "Inserts at the beginning of the block.")
      .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
                  py::arg("block"), "Inserts before the block terminator.")
      .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
           "Inserts an operation.");
2158 
2159   //----------------------------------------------------------------------------
2160   // Mapping of PyAttribute.
2161   //----------------------------------------------------------------------------
2162   py::class_<PyAttribute>(m, "Attribute")
2163       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2164                              &PyAttribute::getCapsule)
2165       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2166       .def_static(
2167           "parse",
2168           [](std::string attrSpec, DefaultingPyMlirContext context) {
2169             MlirAttribute type = mlirAttributeParseGet(
2170                 context->get(), toMlirStringRef(attrSpec));
2171             // TODO: Rework error reporting once diagnostic engine is exposed
2172             // in C API.
2173             if (mlirAttributeIsNull(type)) {
2174               throw SetPyError(PyExc_ValueError,
2175                                Twine("Unable to parse attribute: '") +
2176                                    attrSpec + "'");
2177             }
2178             return PyAttribute(context->getRef(), type);
2179           },
2180           py::arg("asm"), py::arg("context") = py::none(),
2181           "Parses an attribute from an assembly form")
2182       .def_property_readonly(
2183           "context",
2184           [](PyAttribute &self) { return self.getContext().getObject(); },
2185           "Context that owns the Attribute")
2186       .def_property_readonly("type",
2187                              [](PyAttribute &self) {
2188                                return PyType(self.getContext()->getRef(),
2189                                              mlirAttributeGetType(self));
2190                              })
2191       .def(
2192           "get_named",
2193           [](PyAttribute &self, std::string name) {
2194             return PyNamedAttribute(self, std::move(name));
2195           },
2196           py::keep_alive<0, 1>(), "Binds a name to the attribute")
2197       .def("__eq__",
2198            [](PyAttribute &self, PyAttribute &other) { return self == other; })
2199       .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
2200       .def(
2201           "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
2202           kDumpDocstring)
2203       .def(
2204           "__str__",
2205           [](PyAttribute &self) {
2206             PyPrintAccumulator printAccum;
2207             mlirAttributePrint(self, printAccum.getCallback(),
2208                                printAccum.getUserData());
2209             return printAccum.join();
2210           },
2211           "Returns the assembly form of the Attribute.")
2212       .def("__repr__", [](PyAttribute &self) {
2213         // Generally, assembly formats are not printed for __repr__ because
2214         // this can cause exceptionally long debug output and exceptions.
2215         // However, attribute values are generally considered useful and are
2216         // printed. This may need to be re-evaluated if debug dumps end up
2217         // being excessive.
2218         PyPrintAccumulator printAccum;
2219         printAccum.parts.append("Attribute(");
2220         mlirAttributePrint(self, printAccum.getCallback(),
2221                            printAccum.getUserData());
2222         printAccum.parts.append(")");
2223         return printAccum.join();
2224       });
2225 
2226   //----------------------------------------------------------------------------
2227   // Mapping of PyNamedAttribute
2228   //----------------------------------------------------------------------------
2229   py::class_<PyNamedAttribute>(m, "NamedAttribute")
2230       .def("__repr__",
2231            [](PyNamedAttribute &self) {
2232              PyPrintAccumulator printAccum;
2233              printAccum.parts.append("NamedAttribute(");
2234              printAccum.parts.append(
2235                  mlirIdentifierStr(self.namedAttr.name).data);
2236              printAccum.parts.append("=");
2237              mlirAttributePrint(self.namedAttr.attribute,
2238                                 printAccum.getCallback(),
2239                                 printAccum.getUserData());
2240              printAccum.parts.append(")");
2241              return printAccum.join();
2242            })
2243       .def_property_readonly(
2244           "name",
2245           [](PyNamedAttribute &self) {
2246             return py::str(mlirIdentifierStr(self.namedAttr.name).data,
2247                            mlirIdentifierStr(self.namedAttr.name).length);
2248           },
2249           "The name of the NamedAttribute binding")
2250       .def_property_readonly(
2251           "attr",
2252           [](PyNamedAttribute &self) {
2253             // TODO: When named attribute is removed/refactored, also remove
2254             // this constructor (it does an inefficient table lookup).
2255             auto contextRef = PyMlirContext::forContext(
2256                 mlirAttributeGetContext(self.namedAttr.attribute));
2257             return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
2258           },
2259           py::keep_alive<0, 1>(),
2260           "The underlying generic attribute of the NamedAttribute binding");
2261 
2262   //----------------------------------------------------------------------------
2263   // Mapping of PyType.
2264   //----------------------------------------------------------------------------
2265   py::class_<PyType>(m, "Type")
2266       .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
2267       .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
2268       .def_static(
2269           "parse",
2270           [](std::string typeSpec, DefaultingPyMlirContext context) {
2271             MlirType type =
2272                 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
2273             // TODO: Rework error reporting once diagnostic engine is exposed
2274             // in C API.
2275             if (mlirTypeIsNull(type)) {
2276               throw SetPyError(PyExc_ValueError,
2277                                Twine("Unable to parse type: '") + typeSpec +
2278                                    "'");
2279             }
2280             return PyType(context->getRef(), type);
2281           },
2282           py::arg("asm"), py::arg("context") = py::none(),
2283           kContextParseTypeDocstring)
2284       .def_property_readonly(
2285           "context", [](PyType &self) { return self.getContext().getObject(); },
2286           "Context that owns the Type")
2287       .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
2288       .def("__eq__", [](PyType &self, py::object &other) { return false; })
2289       .def(
2290           "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
2291       .def(
2292           "__str__",
2293           [](PyType &self) {
2294             PyPrintAccumulator printAccum;
2295             mlirTypePrint(self, printAccum.getCallback(),
2296                           printAccum.getUserData());
2297             return printAccum.join();
2298           },
2299           "Returns the assembly form of the type.")
2300       .def("__repr__", [](PyType &self) {
2301         // Generally, assembly formats are not printed for __repr__ because
2302         // this can cause exceptionally long debug output and exceptions.
2303         // However, types are an exception as they typically have compact
2304         // assembly forms and printing them is useful.
2305         PyPrintAccumulator printAccum;
2306         printAccum.parts.append("Type(");
2307         mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
2308         printAccum.parts.append(")");
2309         return printAccum.join();
2310       });
2311 
2312   //----------------------------------------------------------------------------
  // Mapping of PyValue.
2314   //----------------------------------------------------------------------------
2315   py::class_<PyValue>(m, "Value")
2316       .def_property_readonly(
2317           "context",
2318           [](PyValue &self) { return self.getParentOperation()->getContext(); },
2319           "Context in which the value lives.")
2320       .def(
2321           "dump", [](PyValue &self) { mlirValueDump(self.get()); },
2322           kDumpDocstring)
2323       .def("__eq__",
2324            [](PyValue &self, PyValue &other) {
2325              return self.get().ptr == other.get().ptr;
2326            })
2327       .def("__eq__", [](PyValue &self, py::object other) { return false; })
2328       .def(
2329           "__str__",
2330           [](PyValue &self) {
2331             PyPrintAccumulator printAccum;
2332             printAccum.parts.append("Value(");
2333             mlirValuePrint(self.get(), printAccum.getCallback(),
2334                            printAccum.getUserData());
2335             printAccum.parts.append(")");
2336             return printAccum.join();
2337           },
2338           kValueDunderStrDocstring)
2339       .def_property_readonly("type", [](PyValue &self) {
2340         return PyType(self.getParentOperation()->getContext(),
2341                       mlirValueGetType(self.get()));
2342       });
  // Each static bind(m) call below adds that helper's bindings to module `m`.
  // NOTE(review): these two run right after the base "Value" class is
  // registered; presumably that order is required for them — confirm before
  // moving them.
  PyBlockArgument::bind(m);
  PyOpResult::bind(m);

  // Container bindings.
  // NOTE(review): looks like list/iterator/map views over blocks, operations,
  // regions, operands, results and attributes (judging by the class names);
  // verify against their definitions.
  PyBlockArgumentList::bind(m);
  PyBlockIterator::bind(m);
  PyBlockList::bind(m);
  PyOperationIterator::bind(m);
  PyOperationList::bind(m);
  PyOpAttributeMap::bind(m);
  PyOpOperandList::bind(m);
  PyOpResultList::bind(m);
  PyRegionIterator::bind(m);
  PyRegionList::bind(m);
2357 }
2358