1 //===- IRModules.h - IR Submodules of pybind module -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
10 #define MLIR_BINDINGS_PYTHON_IRMODULES_H
11 
12 #include <vector>
13 
14 #include "PybindUtils.h"
15 
16 #include "mlir-c/AffineExpr.h"
17 #include "mlir-c/AffineMap.h"
18 #include "mlir-c/Diagnostics.h"
19 #include "mlir-c/IR.h"
20 #include "mlir-c/IntegerSet.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/Optional.h"
23 
24 namespace mlir {
25 namespace python {
26 
27 class PyBlock;
28 class PyDiagnostic;
29 class PyDiagnosticHandler;
30 class PyInsertionPoint;
31 class PyLocation;
32 class DefaultingPyLocation;
33 class PyMlirContext;
34 class DefaultingPyMlirContext;
35 class PyModule;
36 class PyOperation;
37 class PyType;
38 class PySymbolTable;
39 class PyValue;
40 
41 /// Template for a reference to a concrete type which captures a python
42 /// reference to its underlying python object.
43 template <typename T>
44 class PyObjectRef {
45 public:
46   PyObjectRef(T *referrent, pybind11::object object)
47       : referrent(referrent), object(std::move(object)) {
48     assert(this->referrent &&
49            "cannot construct PyObjectRef with null referrent");
50     assert(this->object && "cannot construct PyObjectRef with null object");
51   }
52   PyObjectRef(PyObjectRef &&other)
53       : referrent(other.referrent), object(std::move(other.object)) {
54     other.referrent = nullptr;
55     assert(!other.object);
56   }
57   PyObjectRef(const PyObjectRef &other)
58       : referrent(other.referrent), object(other.object /* copies */) {}
59   ~PyObjectRef() {}
60 
61   int getRefCount() {
62     if (!object)
63       return 0;
64     return object.ref_count();
65   }
66 
67   /// Releases the object held by this instance, returning it.
68   /// This is the proper thing to return from a function that wants to return
69   /// the reference. Note that this does not work from initializers.
70   pybind11::object releaseObject() {
71     assert(referrent && object);
72     referrent = nullptr;
73     auto stolen = std::move(object);
74     return stolen;
75   }
76 
77   T *get() { return referrent; }
78   T *operator->() {
79     assert(referrent && object);
80     return referrent;
81   }
82   pybind11::object getObject() {
83     assert(referrent && object);
84     return object;
85   }
86   operator bool() const { return referrent && object; }
87 
88 private:
89   T *referrent;
90   pybind11::object object;
91 };
92 
93 /// Tracks an entry in the thread context stack. New entries are pushed onto
94 /// here for each with block that activates a new InsertionPoint, Context or
95 /// Location.
96 ///
97 /// Pushing either a Location or InsertionPoint also pushes its associated
98 /// Context. Pushing a Context will not modify the Location or InsertionPoint
99 /// unless if they are from a different context, in which case, they are
100 /// cleared.
101 class PyThreadContextEntry {
102 public:
103   enum class FrameKind {
104     Context,
105     InsertionPoint,
106     Location,
107   };
108 
109   PyThreadContextEntry(FrameKind frameKind, pybind11::object context,
110                        pybind11::object insertionPoint,
111                        pybind11::object location)
112       : context(std::move(context)), insertionPoint(std::move(insertionPoint)),
113         location(std::move(location)), frameKind(frameKind) {}
114 
115   /// Gets the top of stack context and return nullptr if not defined.
116   static PyMlirContext *getDefaultContext();
117 
118   /// Gets the top of stack insertion point and return nullptr if not defined.
119   static PyInsertionPoint *getDefaultInsertionPoint();
120 
121   /// Gets the top of stack location and returns nullptr if not defined.
122   static PyLocation *getDefaultLocation();
123 
124   PyMlirContext *getContext();
125   PyInsertionPoint *getInsertionPoint();
126   PyLocation *getLocation();
127   FrameKind getFrameKind() { return frameKind; }
128 
129   /// Stack management.
130   static PyThreadContextEntry *getTopOfStack();
131   static pybind11::object pushContext(PyMlirContext &context);
132   static void popContext(PyMlirContext &context);
133   static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint);
134   static void popInsertionPoint(PyInsertionPoint &insertionPoint);
135   static pybind11::object pushLocation(PyLocation &location);
136   static void popLocation(PyLocation &location);
137 
138   /// Gets the thread local stack.
139   static std::vector<PyThreadContextEntry> &getStack();
140 
141 private:
142   static void push(FrameKind frameKind, pybind11::object context,
143                    pybind11::object insertionPoint, pybind11::object location);
144 
145   /// An object reference to the PyContext.
146   pybind11::object context;
147   /// An object reference to the current insertion point.
148   pybind11::object insertionPoint;
149   /// An object reference to the current location.
150   pybind11::object location;
151   // The kind of push that was performed.
152   FrameKind frameKind;
153 };
154 
155 /// Wrapper around MlirContext.
156 using PyMlirContextRef = PyObjectRef<PyMlirContext>;
157 class PyMlirContext {
158 public:
159   PyMlirContext() = delete;
160   PyMlirContext(const PyMlirContext &) = delete;
161   PyMlirContext(PyMlirContext &&) = delete;
162 
163   /// For the case of a python __init__ (py::init) method, pybind11 is quite
164   /// strict about needing to return a pointer that is not yet associated to
165   /// an py::object. Since the forContext() method acts like a pool, possibly
166   /// returning a recycled context, it does not satisfy this need. The usual
167   /// way in python to accomplish such a thing is to override __new__, but
168   /// that is also not supported by pybind11. Instead, we use this entry
169   /// point which always constructs a fresh context (which cannot alias an
170   /// existing one because it is fresh).
171   static PyMlirContext *createNewContextForInit();
172 
173   /// Returns a context reference for the singleton PyMlirContext wrapper for
174   /// the given context.
175   static PyMlirContextRef forContext(MlirContext context);
176   ~PyMlirContext();
177 
178   /// Accesses the underlying MlirContext.
179   MlirContext get() { return context; }
180 
181   /// Gets a strong reference to this context, which will ensure it is kept
182   /// alive for the life of the reference.
183   PyMlirContextRef getRef() {
184     return PyMlirContextRef(this, pybind11::cast(this));
185   }
186 
187   /// Gets a capsule wrapping the void* within the MlirContext.
188   pybind11::object getCapsule();
189 
190   /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
191   /// Note that PyMlirContext instances are uniqued, so the returned object
192   /// may be a pre-existing object. Ownership of the underlying MlirContext
193   /// is taken by calling this function.
194   static pybind11::object createFromCapsule(pybind11::object capsule);
195 
196   /// Gets the count of live context objects. Used for testing.
197   static size_t getLiveCount();
198 
199   /// Gets the count of live operations associated with this context.
200   /// Used for testing.
201   size_t getLiveOperationCount();
202 
203   /// Gets the count of live modules associated with this context.
204   /// Used for testing.
205   size_t getLiveModuleCount();
206 
207   /// Enter and exit the context manager.
208   pybind11::object contextEnter();
209   void contextExit(const pybind11::object &excType,
210                    const pybind11::object &excVal,
211                    const pybind11::object &excTb);
212 
213   /// Attaches a Python callback as a diagnostic handler, returning a
214   /// registration object (internally a PyDiagnosticHandler).
215   pybind11::object attachDiagnosticHandler(pybind11::object callback);
216 
217 private:
218   PyMlirContext(MlirContext context);
219   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
220   // preserving the relationship that an MlirContext maps to a single
221   // PyMlirContext wrapper. This could be replaced in the future with an
222   // extension mechanism on the MlirContext for stashing user pointers.
223   // Note that this holds a handle, which does not imply ownership.
224   // Mappings will be removed when the context is destructed.
225   using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
226   static LiveContextMap &getLiveContexts();
227 
228   // Interns all live modules associated with this context. Modules tracked
229   // in this map are valid. When a module is invalidated, it is removed
230   // from this map, and while it still exists as an instance, any
231   // attempt to access it will raise an error.
232   using LiveModuleMap =
233       llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
234   LiveModuleMap liveModules;
235 
236   // Interns all live operations associated with this context. Operations
237   // tracked in this map are valid. When an operation is invalidated, it is
238   // removed from this map, and while it still exists as an instance, any
239   // attempt to access it will raise an error.
240   using LiveOperationMap =
241       llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
242   LiveOperationMap liveOperations;
243 
244   MlirContext context;
245   friend class PyModule;
246   friend class PyOperation;
247 };
248 
249 /// Used in function arguments when None should resolve to the current context
250 /// manager set instance.
251 class DefaultingPyMlirContext
252     : public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
253 public:
254   using Defaulting::Defaulting;
255   static constexpr const char kTypeDescription[] = "mlir.ir.Context";
256   static PyMlirContext &resolve();
257 };
258 
259 /// Base class for all objects that directly or indirectly depend on an
260 /// MlirContext. The lifetime of the context will extend at least to the
261 /// lifetime of these instances.
262 /// Immutable objects that depend on a context extend this directly.
263 class BaseContextObject {
264 public:
265   BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
266     assert(this->contextRef &&
267            "context object constructed with null context ref");
268   }
269 
270   /// Accesses the context reference.
271   PyMlirContextRef &getContext() { return contextRef; }
272 
273 private:
274   PyMlirContextRef contextRef;
275 };
276 
277 /// Python class mirroring the C MlirDiagnostic struct. Note that these structs
278 /// are only valid for the duration of a diagnostic callback and attempting
279 /// to access them outside of that will raise an exception. This applies to
280 /// nested diagnostics (in the notes) as well.
281 class PyDiagnostic {
282 public:
283   PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}
284   void invalidate();
285   bool isValid() { return valid; }
286   MlirDiagnosticSeverity getSeverity();
287   PyLocation getLocation();
288   pybind11::str getMessage();
289   pybind11::tuple getNotes();
290 
291 private:
292   MlirDiagnostic diagnostic;
293 
294   void checkValid();
295   /// If notes have been materialized from the diagnostic, then this will
296   /// be populated with the corresponding objects (all castable to
297   /// PyDiagnostic).
298   llvm::Optional<pybind11::tuple> materializedNotes;
299   bool valid = true;
300 };
301 
302 /// Represents a diagnostic handler attached to the context. The handler's
303 /// callback will be invoked with PyDiagnostic instances until the detach()
304 /// method is called or the context is destroyed. A diagnostic handler can be
305 /// the subject of a `with` block, which will detach it when the block exits.
306 ///
307 /// Since diagnostic handlers can call back into Python code which can do
308 /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions,
309 /// etc), this is generally not deemed to be a great user-level API. Users
310 /// should generally use some form of DiagnosticCollector. If the handler raises
311 /// any exceptions, they will just be emitted to stderr and dropped.
312 ///
313 /// The unique usage of this class means that its lifetime management is
314 /// different from most other parts of the API. Instances are always created
315 /// in an attached state and can transition to a detached state by either:
316 ///   a) The context being destroyed and unregistering all handlers.
317 ///   b) An explicit call to detach().
318 /// The object may remain live from a Python perspective for an arbitrary time
319 /// after detachment, but there is nothing the user can do with it (since there
320 /// is no way to attach an existing handler object).
321 class PyDiagnosticHandler {
322 public:
323   PyDiagnosticHandler(MlirContext context, pybind11::object callback);
324   ~PyDiagnosticHandler();
325 
326   bool isAttached() { return registeredID.hasValue(); }
327   bool getHadError() { return hadError; }
328 
329   /// Detaches the handler. Does nothing if not attached.
330   void detach();
331 
332   pybind11::object contextEnter() { return pybind11::cast(this); }
333   void contextExit(pybind11::object excType, pybind11::object excVal,
334                    pybind11::object excTb) {
335     detach();
336   }
337 
338 private:
339   MlirContext context;
340   pybind11::object callback;
341   llvm::Optional<MlirDiagnosticHandlerID> registeredID;
342   bool hadError = false;
343   friend class PyMlirContext;
344 };
345 
346 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
347 /// order to differentiate it from the `Dialect` base class which is extended by
348 /// plugins which extend dialect functionality through extension python code.
349 /// This should be seen as the "low-level" object and `Dialect` as the
350 /// high-level, user facing object.
351 class PyDialectDescriptor : public BaseContextObject {
352 public:
353   PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
354       : BaseContextObject(std::move(contextRef)), dialect(dialect) {}
355 
356   MlirDialect get() { return dialect; }
357 
358 private:
359   MlirDialect dialect;
360 };
361 
362 /// User-level object for accessing dialects with dotted syntax such as:
363 ///   ctx.dialect.std
364 class PyDialects : public BaseContextObject {
365 public:
366   PyDialects(PyMlirContextRef contextRef)
367       : BaseContextObject(std::move(contextRef)) {}
368 
369   MlirDialect getDialectForKey(const std::string &key, bool attrError);
370 };
371 
372 /// User-level dialect object. For dialects that have a registered extension,
373 /// this will be the base class of the extension dialect type. For un-extended,
374 /// objects of this type will be returned directly.
375 class PyDialect {
376 public:
377   PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}
378 
379   pybind11::object getDescriptor() { return descriptor; }
380 
381 private:
382   pybind11::object descriptor;
383 };
384 
385 /// Wrapper around an MlirLocation.
386 class PyLocation : public BaseContextObject {
387 public:
388   PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
389       : BaseContextObject(std::move(contextRef)), loc(loc) {}
390 
391   operator MlirLocation() const { return loc; }
392   MlirLocation get() const { return loc; }
393 
394   /// Enter and exit the context manager.
395   pybind11::object contextEnter();
396   void contextExit(const pybind11::object &excType,
397                    const pybind11::object &excVal,
398                    const pybind11::object &excTb);
399 
400   /// Gets a capsule wrapping the void* within the MlirLocation.
401   pybind11::object getCapsule();
402 
403   /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
404   /// Note that PyLocation instances are uniqued, so the returned object
405   /// may be a pre-existing object. Ownership of the underlying MlirLocation
406   /// is taken by calling this function.
407   static PyLocation createFromCapsule(pybind11::object capsule);
408 
409 private:
410   MlirLocation loc;
411 };
412 
413 /// Used in function arguments when None should resolve to the current context
414 /// manager set instance.
415 class DefaultingPyLocation
416     : public Defaulting<DefaultingPyLocation, PyLocation> {
417 public:
418   using Defaulting::Defaulting;
419   static constexpr const char kTypeDescription[] = "mlir.ir.Location";
420   static PyLocation &resolve();
421 
422   operator MlirLocation() const { return *get(); }
423 };
424 
425 /// Wrapper around MlirModule.
426 /// This is the top-level, user-owned object that contains regions/ops/blocks.
427 class PyModule;
428 using PyModuleRef = PyObjectRef<PyModule>;
429 class PyModule : public BaseContextObject {
430 public:
431   /// Returns a PyModule reference for the given MlirModule. This may return
432   /// a pre-existing or new object.
433   static PyModuleRef forModule(MlirModule module);
434   PyModule(PyModule &) = delete;
435   PyModule(PyMlirContext &&) = delete;
436   ~PyModule();
437 
438   /// Gets the backing MlirModule.
439   MlirModule get() { return module; }
440 
441   /// Gets a strong reference to this module.
442   PyModuleRef getRef() {
443     return PyModuleRef(this,
444                        pybind11::reinterpret_borrow<pybind11::object>(handle));
445   }
446 
447   /// Gets a capsule wrapping the void* within the MlirModule.
448   /// Note that the module does not (yet) provide a corresponding factory for
449   /// constructing from a capsule as that would require uniquing PyModule
450   /// instances, which is not currently done.
451   pybind11::object getCapsule();
452 
453   /// Creates a PyModule from the MlirModule wrapped by a capsule.
454   /// Note that PyModule instances are uniqued, so the returned object
455   /// may be a pre-existing object. Ownership of the underlying MlirModule
456   /// is taken by calling this function.
457   static pybind11::object createFromCapsule(pybind11::object capsule);
458 
459 private:
460   PyModule(PyMlirContextRef contextRef, MlirModule module);
461   MlirModule module;
462   pybind11::handle handle;
463 };
464 
465 /// Base class for PyOperation and PyOpView which exposes the primary, user
466 /// visible methods for manipulating it.
467 class PyOperationBase {
468 public:
469   virtual ~PyOperationBase() = default;
470   /// Implements the bound 'print' method and helps with others.
471   void print(pybind11::object fileObject, bool binary,
472              llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
473              bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
474              bool assumeVerified);
475   pybind11::object getAsm(bool binary,
476                           llvm::Optional<int64_t> largeElementsLimit,
477                           bool enableDebugInfo, bool prettyDebugInfo,
478                           bool printGenericOpForm, bool useLocalScope,
479                           bool assumeVerified);
480 
481   /// Moves the operation before or after the other operation.
482   void moveAfter(PyOperationBase &other);
483   void moveBefore(PyOperationBase &other);
484 
485   /// Each must provide access to the raw Operation.
486   virtual PyOperation &getOperation() = 0;
487 };
488 
489 /// Wrapper around PyOperation.
490 /// Operations exist in either an attached (dependent) or detached (top-level)
491 /// state. In the detached state (as on creation), an operation is owned by
492 /// the creator and its lifetime extends either until its reference count
493 /// drops to zero or it is attached to a parent, at which point its lifetime
494 /// is bounded by its top-level parent reference.
495 class PyOperation;
496 using PyOperationRef = PyObjectRef<PyOperation>;
497 class PyOperation : public PyOperationBase, public BaseContextObject {
498 public:
499   ~PyOperation();
500   PyOperation &getOperation() override { return *this; }
501 
502   /// Returns a PyOperation for the given MlirOperation, optionally associating
503   /// it with a parentKeepAlive.
504   static PyOperationRef
505   forOperation(PyMlirContextRef contextRef, MlirOperation operation,
506                pybind11::object parentKeepAlive = pybind11::object());
507 
508   /// Creates a detached operation. The operation must not be associated with
509   /// any existing live operation.
510   static PyOperationRef
511   createDetached(PyMlirContextRef contextRef, MlirOperation operation,
512                  pybind11::object parentKeepAlive = pybind11::object());
513 
514   /// Detaches the operation from its parent block and updates its state
515   /// accordingly.
516   void detachFromParent() {
517     mlirOperationRemoveFromParent(getOperation());
518     setDetached();
519     parentKeepAlive = pybind11::object();
520   }
521 
522   /// Gets the backing operation.
523   operator MlirOperation() const { return get(); }
524   MlirOperation get() const {
525     checkValid();
526     return operation;
527   }
528 
529   PyOperationRef getRef() {
530     return PyOperationRef(
531         this, pybind11::reinterpret_borrow<pybind11::object>(handle));
532   }
533 
534   bool isAttached() { return attached; }
535   void setAttached(pybind11::object parent = pybind11::object()) {
536     assert(!attached && "operation already attached");
537     attached = true;
538   }
539   void setDetached() {
540     assert(attached && "operation already detached");
541     attached = false;
542   }
543   void checkValid() const;
544 
545   /// Gets the owning block or raises an exception if the operation has no
546   /// owning block.
547   PyBlock getBlock();
548 
549   /// Gets the parent operation or raises an exception if the operation has
550   /// no parent.
551   llvm::Optional<PyOperationRef> getParentOperation();
552 
553   /// Gets a capsule wrapping the void* within the MlirOperation.
554   pybind11::object getCapsule();
555 
556   /// Creates a PyOperation from the MlirOperation wrapped by a capsule.
557   /// Ownership of the underlying MlirOperation is taken by calling this
558   /// function.
559   static pybind11::object createFromCapsule(pybind11::object capsule);
560 
561   /// Creates an operation. See corresponding python docstring.
562   static pybind11::object
563   create(const std::string &name, llvm::Optional<std::vector<PyType *>> results,
564          llvm::Optional<std::vector<PyValue *>> operands,
565          llvm::Optional<pybind11::dict> attributes,
566          llvm::Optional<std::vector<PyBlock *>> successors, int regions,
567          DefaultingPyLocation location, const pybind11::object &ip);
568 
569   /// Creates an OpView suitable for this operation.
570   pybind11::object createOpView();
571 
572   /// Erases the underlying MlirOperation, removes its pointer from the
573   /// parent context's live operations map, and sets the valid bit false.
574   void erase();
575 
576 private:
577   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
578   static PyOperationRef createInstance(PyMlirContextRef contextRef,
579                                        MlirOperation operation,
580                                        pybind11::object parentKeepAlive);
581 
582   MlirOperation operation;
583   pybind11::handle handle;
584   // Keeps the parent alive, regardless of whether it is an Operation or
585   // Module.
586   // TODO: As implemented, this facility is only sufficient for modeling the
587   // trivial module parent back-reference. Generalize this to also account for
588   // transitions from detached to attached and address TODOs in the
589   // ir_operation.py regarding testing corresponding lifetime guarantees.
590   pybind11::object parentKeepAlive;
591   bool attached = true;
592   bool valid = true;
593 
594   friend class PyOperationBase;
595   friend class PySymbolTable;
596 };
597 
598 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
599 /// providing more instance-specific accessors and serve as the base class for
600 /// custom ODS-style operation classes. Since this class is subclass on the
601 /// python side, it must present an __init__ method that operates in pure
602 /// python types.
603 class PyOpView : public PyOperationBase {
604 public:
605   PyOpView(const pybind11::object &operationObject);
606   PyOperation &getOperation() override { return operation; }
607 
608   static pybind11::object createRawSubclass(const pybind11::object &userClass);
609 
610   pybind11::object getOperationObject() { return operationObject; }
611 
612   static pybind11::object
613   buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList,
614                pybind11::list operandList,
615                llvm::Optional<pybind11::dict> attributes,
616                llvm::Optional<std::vector<PyBlock *>> successors,
617                llvm::Optional<int> regions, DefaultingPyLocation location,
618                const pybind11::object &maybeIp);
619 
620 private:
621   PyOperation &operation;           // For efficient, cast-free access from C++
622   pybind11::object operationObject; // Holds the reference.
623 };
624 
625 /// Wrapper around an MlirRegion.
626 /// Regions are managed completely by their containing operation. Unlike the
627 /// C++ API, the python API does not support detached regions.
628 class PyRegion {
629 public:
630   PyRegion(PyOperationRef parentOperation, MlirRegion region)
631       : parentOperation(std::move(parentOperation)), region(region) {
632     assert(!mlirRegionIsNull(region) && "python region cannot be null");
633   }
634   operator MlirRegion() const { return region; }
635 
636   MlirRegion get() { return region; }
637   PyOperationRef &getParentOperation() { return parentOperation; }
638 
639   void checkValid() { return parentOperation->checkValid(); }
640 
641 private:
642   PyOperationRef parentOperation;
643   MlirRegion region;
644 };
645 
646 /// Wrapper around an MlirBlock.
647 /// Blocks are managed completely by their containing operation. Unlike the
648 /// C++ API, the python API does not support detached blocks.
649 class PyBlock {
650 public:
651   PyBlock(PyOperationRef parentOperation, MlirBlock block)
652       : parentOperation(std::move(parentOperation)), block(block) {
653     assert(!mlirBlockIsNull(block) && "python block cannot be null");
654   }
655 
656   MlirBlock get() { return block; }
657   PyOperationRef &getParentOperation() { return parentOperation; }
658 
659   void checkValid() { return parentOperation->checkValid(); }
660 
661 private:
662   PyOperationRef parentOperation;
663   MlirBlock block;
664 };
665 
666 /// An insertion point maintains a pointer to a Block and a reference operation.
667 /// Calls to insert() will insert a new operation before the
668 /// reference operation. If the reference operation is null, then appends to
669 /// the end of the block.
670 class PyInsertionPoint {
671 public:
672   /// Creates an insertion point positioned after the last operation in the
673   /// block, but still inside the block.
674   PyInsertionPoint(PyBlock &block);
675   /// Creates an insertion point positioned before a reference operation.
676   PyInsertionPoint(PyOperationBase &beforeOperationBase);
677 
678   /// Shortcut to create an insertion point at the beginning of the block.
679   static PyInsertionPoint atBlockBegin(PyBlock &block);
680   /// Shortcut to create an insertion point before the block terminator.
681   static PyInsertionPoint atBlockTerminator(PyBlock &block);
682 
683   /// Inserts an operation.
684   void insert(PyOperationBase &operationBase);
685 
686   /// Enter and exit the context manager.
687   pybind11::object contextEnter();
688   void contextExit(const pybind11::object &excType,
689                    const pybind11::object &excVal,
690                    const pybind11::object &excTb);
691 
692   PyBlock &getBlock() { return block; }
693 
694 private:
695   // Trampoline constructor that avoids null initializing members while
696   // looking up parents.
697   PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation)
698       : refOperation(std::move(refOperation)), block(std::move(block)) {}
699 
700   llvm::Optional<PyOperationRef> refOperation;
701   PyBlock block;
702 };
703 /// Wrapper around the generic MlirType.
704 /// The lifetime of a type is bound by the PyContext that created it.
705 class PyType : public BaseContextObject {
706 public:
707   PyType(PyMlirContextRef contextRef, MlirType type)
708       : BaseContextObject(std::move(contextRef)), type(type) {}
709   bool operator==(const PyType &other);
710   operator MlirType() const { return type; }
711   MlirType get() const { return type; }
712 
713   /// Gets a capsule wrapping the void* within the MlirType.
714   pybind11::object getCapsule();
715 
716   /// Creates a PyType from the MlirType wrapped by a capsule.
717   /// Note that PyType instances are uniqued, so the returned object
718   /// may be a pre-existing object. Ownership of the underlying MlirType
719   /// is taken by calling this function.
720   static PyType createFromCapsule(pybind11::object capsule);
721 
722 private:
723   MlirType type;
724 };
725 
726 /// CRTP base classes for Python types that subclass Type and should be
727 /// castable from it (i.e. via something like IntegerType(t)).
728 /// By default, type class hierarchies are one level deep (i.e. a
729 /// concrete type class extends PyType); however, intermediate python-visible
730 /// base classes can be modeled by specifying a BaseTy.
731 template <typename DerivedTy, typename BaseTy = PyType>
732 class PyConcreteType : public BaseTy {
733 public:
734   // Derived classes must define statics for:
735   //   IsAFunctionTy isaFunction
736   //   const char *pyClassName
737   using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
738   using IsAFunctionTy = bool (*)(MlirType);
739 
740   PyConcreteType() = default;
741   PyConcreteType(PyMlirContextRef contextRef, MlirType t)
742       : BaseTy(std::move(contextRef), t) {}
743   PyConcreteType(PyType &orig)
744       : PyConcreteType(orig.getContext(), castFrom(orig)) {}
745 
746   static MlirType castFrom(PyType &orig) {
747     if (!DerivedTy::isaFunction(orig)) {
748       auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
749       throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
750                                              DerivedTy::pyClassName +
751                                              " (from " + origRepr + ")");
752     }
753     return orig;
754   }
755 
756   static void bind(pybind11::module &m) {
757     auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
758     cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(),
759             pybind11::arg("cast_from_type"));
760     cls.def_static(
761         "isinstance",
762         [](PyType &otherType) -> bool {
763           return DerivedTy::isaFunction(otherType);
764         },
765         pybind11::arg("other"));
766     DerivedTy::bindDerived(cls);
767   }
768 
769   /// Implemented by derived classes to add methods to the Python subclass.
770   static void bindDerived(ClassTy &m) {}
771 };
772 
773 /// Wrapper around the generic MlirAttribute.
774 /// The lifetime of a type is bound by the PyContext that created it.
775 class PyAttribute : public BaseContextObject {
776 public:
777   PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
778       : BaseContextObject(std::move(contextRef)), attr(attr) {}
779   bool operator==(const PyAttribute &other);
780   operator MlirAttribute() const { return attr; }
781   MlirAttribute get() const { return attr; }
782 
783   /// Gets a capsule wrapping the void* within the MlirAttribute.
784   pybind11::object getCapsule();
785 
786   /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
787   /// Note that PyAttribute instances are uniqued, so the returned object
788   /// may be a pre-existing object. Ownership of the underlying MlirAttribute
789   /// is taken by calling this function.
790   static PyAttribute createFromCapsule(pybind11::object capsule);
791 
792 private:
793   MlirAttribute attr;
794 };
795 
796 /// Represents a Python MlirNamedAttr, carrying an optional owned name.
797 /// TODO: Refactor this and the C-API to be based on an Identifier owned
798 /// by the context so as to avoid ownership issues here.
799 class PyNamedAttribute {
800 public:
801   /// Constructs a PyNamedAttr that retains an owned name. This should be
802   /// used in any code that originates an MlirNamedAttribute from a python
803   /// string.
804   /// The lifetime of the PyNamedAttr must extend to the lifetime of the
805   /// passed attribute.
806   PyNamedAttribute(MlirAttribute attr, std::string ownedName);
807 
808   MlirNamedAttribute namedAttr;
809 
810 private:
811   // Since the MlirNamedAttr contains an internal pointer to the actual
812   // memory of the owned string, it must be heap allocated to remain valid.
813   // Otherwise, strings that fit within the small object optimization threshold
814   // will have their memory address change as the containing object is moved,
815   // resulting in an invalid aliased pointer.
816   std::unique_ptr<std::string> ownedName;
817 };
818 
819 /// CRTP base classes for Python attributes that subclass Attribute and should
820 /// be castable from it (i.e. via something like StringAttr(attr)).
821 /// By default, attribute class hierarchies are one level deep (i.e. a
822 /// concrete attribute class extends PyAttribute); however, intermediate
823 /// python-visible base classes can be modeled by specifying a BaseTy.
824 template <typename DerivedTy, typename BaseTy = PyAttribute>
825 class PyConcreteAttribute : public BaseTy {
826 public:
827   // Derived classes must define statics for:
828   //   IsAFunctionTy isaFunction
829   //   const char *pyClassName
830   using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
831   using IsAFunctionTy = bool (*)(MlirAttribute);
832 
833   PyConcreteAttribute() = default;
834   PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
835       : BaseTy(std::move(contextRef), attr) {}
836   PyConcreteAttribute(PyAttribute &orig)
837       : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
838 
839   static MlirAttribute castFrom(PyAttribute &orig) {
840     if (!DerivedTy::isaFunction(orig)) {
841       auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
842       throw SetPyError(PyExc_ValueError,
843                        llvm::Twine("Cannot cast attribute to ") +
844                            DerivedTy::pyClassName + " (from " + origRepr + ")");
845     }
846     return orig;
847   }
848 
849   static void bind(pybind11::module &m) {
850     auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(),
851                        pybind11::module_local());
852     cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(),
853             pybind11::arg("cast_from_attr"));
854     cls.def_static(
855         "isinstance",
856         [](PyAttribute &otherAttr) -> bool {
857           return DerivedTy::isaFunction(otherAttr);
858         },
859         pybind11::arg("other"));
860     cls.def_property_readonly("type", [](PyAttribute &attr) {
861       return PyType(attr.getContext(), mlirAttributeGetType(attr));
862     });
863     DerivedTy::bindDerived(cls);
864   }
865 
866   /// Implemented by derived classes to add methods to the Python subclass.
867   static void bindDerived(ClassTy &m) {}
868 };
869 
870 /// Wrapper around the generic MlirValue.
871 /// Values are managed completely by the operation that resulted in their
872 /// definition. For op result value, this is the operation that defines the
873 /// value. For block argument values, this is the operation that contains the
874 /// block to which the value is an argument (blocks cannot be detached in Python
875 /// bindings so such operation always exists).
876 class PyValue {
877 public:
878   PyValue(PyOperationRef parentOperation, MlirValue value)
879       : parentOperation(parentOperation), value(value) {}
880   operator MlirValue() const { return value; }
881 
882   MlirValue get() { return value; }
883   PyOperationRef &getParentOperation() { return parentOperation; }
884 
885   void checkValid() { return parentOperation->checkValid(); }
886 
887   /// Gets a capsule wrapping the void* within the MlirValue.
888   pybind11::object getCapsule();
889 
890   /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
891   /// the underlying MlirValue is still tied to the owning operation.
892   static PyValue createFromCapsule(pybind11::object capsule);
893 
894 private:
895   PyOperationRef parentOperation;
896   MlirValue value;
897 };
898 
899 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
900 class PyAffineExpr : public BaseContextObject {
901 public:
902   PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
903       : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}
904   bool operator==(const PyAffineExpr &other);
905   operator MlirAffineExpr() const { return affineExpr; }
906   MlirAffineExpr get() const { return affineExpr; }
907 
908   /// Gets a capsule wrapping the void* within the MlirAffineExpr.
909   pybind11::object getCapsule();
910 
911   /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
912   /// Note that PyAffineExpr instances are uniqued, so the returned object
913   /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
914   /// is taken by calling this function.
915   static PyAffineExpr createFromCapsule(pybind11::object capsule);
916 
917   PyAffineExpr add(const PyAffineExpr &other) const;
918   PyAffineExpr mul(const PyAffineExpr &other) const;
919   PyAffineExpr floorDiv(const PyAffineExpr &other) const;
920   PyAffineExpr ceilDiv(const PyAffineExpr &other) const;
921   PyAffineExpr mod(const PyAffineExpr &other) const;
922 
923 private:
924   MlirAffineExpr affineExpr;
925 };
926 
927 class PyAffineMap : public BaseContextObject {
928 public:
929   PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)
930       : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {}
931   bool operator==(const PyAffineMap &other);
932   operator MlirAffineMap() const { return affineMap; }
933   MlirAffineMap get() const { return affineMap; }
934 
935   /// Gets a capsule wrapping the void* within the MlirAffineMap.
936   pybind11::object getCapsule();
937 
938   /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule.
939   /// Note that PyAffineMap instances are uniqued, so the returned object
940   /// may be a pre-existing object. Ownership of the underlying MlirAffineMap
941   /// is taken by calling this function.
942   static PyAffineMap createFromCapsule(pybind11::object capsule);
943 
944 private:
945   MlirAffineMap affineMap;
946 };
947 
948 class PyIntegerSet : public BaseContextObject {
949 public:
950   PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
951       : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
952   bool operator==(const PyIntegerSet &other);
953   operator MlirIntegerSet() const { return integerSet; }
954   MlirIntegerSet get() const { return integerSet; }
955 
956   /// Gets a capsule wrapping the void* within the MlirIntegerSet.
957   pybind11::object getCapsule();
958 
959   /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
960   /// Note that PyIntegerSet instances may be uniqued, so the returned object
961   /// may be a pre-existing object. Integer sets are owned by the context.
962   static PyIntegerSet createFromCapsule(pybind11::object capsule);
963 
964 private:
965   MlirIntegerSet integerSet;
966 };
967 
968 /// Bindings for MLIR symbol tables.
969 class PySymbolTable {
970 public:
971   /// Constructs a symbol table for the given operation.
972   explicit PySymbolTable(PyOperationBase &operation);
973 
974   /// Destroys the symbol table.
975   ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); }
976 
977   /// Returns the symbol (opview) with the given name, throws if there is no
978   /// such symbol in the table.
979   pybind11::object dunderGetItem(const std::string &name);
980 
981   /// Removes the given operation from the symbol table and erases it.
982   void erase(PyOperationBase &symbol);
983 
984   /// Removes the operation with the given name from the symbol table and erases
985   /// it, throws if there is no such symbol in the table.
986   void dunderDel(const std::string &name);
987 
988   /// Inserts the given operation into the symbol table. The operation must have
989   /// the symbol trait.
990   PyAttribute insert(PyOperationBase &symbol);
991 
992   /// Gets and sets the name of a symbol op.
993   static PyAttribute getSymbolName(PyOperationBase &symbol);
994   static void setSymbolName(PyOperationBase &symbol, const std::string &name);
995 
996   /// Gets and sets the visibility of a symbol op.
997   static PyAttribute getVisibility(PyOperationBase &symbol);
998   static void setVisibility(PyOperationBase &symbol,
999                             const std::string &visibility);
1000 
1001   /// Replaces all symbol uses within an operation. See the API
1002   /// mlirSymbolTableReplaceAllSymbolUses for all caveats.
1003   static void replaceAllSymbolUses(const std::string &oldSymbol,
1004                                    const std::string &newSymbol,
1005                                    PyOperationBase &from);
1006 
1007   /// Walks all symbol tables under and including 'from'.
1008   static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible,
1009                                pybind11::object callback);
1010 
1011   /// Casts the bindings class into the C API structure.
1012   operator MlirSymbolTable() { return symbolTable; }
1013 
1014 private:
1015   PyOperationRef operation;
1016   MlirSymbolTable symbolTable;
1017 };
1018 
1019 void populateIRAffine(pybind11::module &m);
1020 void populateIRAttributes(pybind11::module &m);
1021 void populateIRCore(pybind11::module &m);
1022 void populateIRInterfaces(pybind11::module &m);
1023 void populateIRTypes(pybind11::module &m);
1024 
1025 } // namespace python
1026 } // namespace mlir
1027 
1028 namespace pybind11 {
1029 namespace detail {
1030 
1031 template <>
1032 struct type_caster<mlir::python::DefaultingPyMlirContext>
1033     : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {};
1034 template <>
1035 struct type_caster<mlir::python::DefaultingPyLocation>
1036     : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};
1037 
1038 } // namespace detail
1039 } // namespace pybind11
1040 
1041 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
1042