1436c6c9cSStella Laurenzo //===- IRModules.h - IR Submodules of pybind module -----------------------===//
2436c6c9cSStella Laurenzo //
3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information.
5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6436c6c9cSStella Laurenzo //
7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===//
8436c6c9cSStella Laurenzo 
9436c6c9cSStella Laurenzo #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
10436c6c9cSStella Laurenzo #define MLIR_BINDINGS_PYTHON_IRMODULES_H
11436c6c9cSStella Laurenzo 
12e8d07395SMehdi Amini #include <utility>
13436c6c9cSStella Laurenzo #include <vector>
14436c6c9cSStella Laurenzo 
15436c6c9cSStella Laurenzo #include "PybindUtils.h"
16436c6c9cSStella Laurenzo 
17436c6c9cSStella Laurenzo #include "mlir-c/AffineExpr.h"
18436c6c9cSStella Laurenzo #include "mlir-c/AffineMap.h"
197ee25bc5SStella Laurenzo #include "mlir-c/Diagnostics.h"
20436c6c9cSStella Laurenzo #include "mlir-c/IR.h"
21436c6c9cSStella Laurenzo #include "mlir-c/IntegerSet.h"
22436c6c9cSStella Laurenzo #include "llvm/ADT/DenseMap.h"
231689dadeSJohn Demme #include "llvm/ADT/Optional.h"
24436c6c9cSStella Laurenzo 
25436c6c9cSStella Laurenzo namespace mlir {
26436c6c9cSStella Laurenzo namespace python {
27436c6c9cSStella Laurenzo 
28436c6c9cSStella Laurenzo class PyBlock;
297ee25bc5SStella Laurenzo class PyDiagnostic;
307ee25bc5SStella Laurenzo class PyDiagnosticHandler;
31436c6c9cSStella Laurenzo class PyInsertionPoint;
32436c6c9cSStella Laurenzo class PyLocation;
33436c6c9cSStella Laurenzo class DefaultingPyLocation;
34436c6c9cSStella Laurenzo class PyMlirContext;
35436c6c9cSStella Laurenzo class DefaultingPyMlirContext;
36436c6c9cSStella Laurenzo class PyModule;
37436c6c9cSStella Laurenzo class PyOperation;
38436c6c9cSStella Laurenzo class PyType;
3930d61893SAlex Zinenko class PySymbolTable;
40436c6c9cSStella Laurenzo class PyValue;
41436c6c9cSStella Laurenzo 
42436c6c9cSStella Laurenzo /// Template for a reference to a concrete type which captures a python
43436c6c9cSStella Laurenzo /// reference to its underlying python object.
44436c6c9cSStella Laurenzo template <typename T>
45436c6c9cSStella Laurenzo class PyObjectRef {
46436c6c9cSStella Laurenzo public:
PyObjectRef(T * referrent,pybind11::object object)47436c6c9cSStella Laurenzo   PyObjectRef(T *referrent, pybind11::object object)
48436c6c9cSStella Laurenzo       : referrent(referrent), object(std::move(object)) {
49436c6c9cSStella Laurenzo     assert(this->referrent &&
50436c6c9cSStella Laurenzo            "cannot construct PyObjectRef with null referrent");
51436c6c9cSStella Laurenzo     assert(this->object && "cannot construct PyObjectRef with null object");
52436c6c9cSStella Laurenzo   }
PyObjectRef(PyObjectRef && other)53436c6c9cSStella Laurenzo   PyObjectRef(PyObjectRef &&other)
54436c6c9cSStella Laurenzo       : referrent(other.referrent), object(std::move(other.object)) {
55436c6c9cSStella Laurenzo     other.referrent = nullptr;
56436c6c9cSStella Laurenzo     assert(!other.object);
57436c6c9cSStella Laurenzo   }
PyObjectRef(const PyObjectRef & other)58436c6c9cSStella Laurenzo   PyObjectRef(const PyObjectRef &other)
59436c6c9cSStella Laurenzo       : referrent(other.referrent), object(other.object /* copies */) {}
609940dcfaSMehdi Amini   ~PyObjectRef() = default;
61436c6c9cSStella Laurenzo 
getRefCount()62436c6c9cSStella Laurenzo   int getRefCount() {
63436c6c9cSStella Laurenzo     if (!object)
64436c6c9cSStella Laurenzo       return 0;
65436c6c9cSStella Laurenzo     return object.ref_count();
66436c6c9cSStella Laurenzo   }
67436c6c9cSStella Laurenzo 
68436c6c9cSStella Laurenzo   /// Releases the object held by this instance, returning it.
69436c6c9cSStella Laurenzo   /// This is the proper thing to return from a function that wants to return
70436c6c9cSStella Laurenzo   /// the reference. Note that this does not work from initializers.
releaseObject()71436c6c9cSStella Laurenzo   pybind11::object releaseObject() {
72436c6c9cSStella Laurenzo     assert(referrent && object);
73436c6c9cSStella Laurenzo     referrent = nullptr;
74436c6c9cSStella Laurenzo     auto stolen = std::move(object);
75436c6c9cSStella Laurenzo     return stolen;
76436c6c9cSStella Laurenzo   }
77436c6c9cSStella Laurenzo 
get()78436c6c9cSStella Laurenzo   T *get() { return referrent; }
79436c6c9cSStella Laurenzo   T *operator->() {
80436c6c9cSStella Laurenzo     assert(referrent && object);
81436c6c9cSStella Laurenzo     return referrent;
82436c6c9cSStella Laurenzo   }
getObject()83436c6c9cSStella Laurenzo   pybind11::object getObject() {
84436c6c9cSStella Laurenzo     assert(referrent && object);
85436c6c9cSStella Laurenzo     return object;
86436c6c9cSStella Laurenzo   }
87436c6c9cSStella Laurenzo   operator bool() const { return referrent && object; }
88436c6c9cSStella Laurenzo 
89436c6c9cSStella Laurenzo private:
90436c6c9cSStella Laurenzo   T *referrent;
91436c6c9cSStella Laurenzo   pybind11::object object;
92436c6c9cSStella Laurenzo };
93436c6c9cSStella Laurenzo 
94436c6c9cSStella Laurenzo /// Tracks an entry in the thread context stack. New entries are pushed onto
95436c6c9cSStella Laurenzo /// here for each with block that activates a new InsertionPoint, Context or
96436c6c9cSStella Laurenzo /// Location.
97436c6c9cSStella Laurenzo ///
98436c6c9cSStella Laurenzo /// Pushing either a Location or InsertionPoint also pushes its associated
99436c6c9cSStella Laurenzo /// Context. Pushing a Context will not modify the Location or InsertionPoint
100436c6c9cSStella Laurenzo /// unless if they are from a different context, in which case, they are
101436c6c9cSStella Laurenzo /// cleared.
102436c6c9cSStella Laurenzo class PyThreadContextEntry {
103436c6c9cSStella Laurenzo public:
104436c6c9cSStella Laurenzo   enum class FrameKind {
105436c6c9cSStella Laurenzo     Context,
106436c6c9cSStella Laurenzo     InsertionPoint,
107436c6c9cSStella Laurenzo     Location,
108436c6c9cSStella Laurenzo   };
109436c6c9cSStella Laurenzo 
PyThreadContextEntry(FrameKind frameKind,pybind11::object context,pybind11::object insertionPoint,pybind11::object location)110436c6c9cSStella Laurenzo   PyThreadContextEntry(FrameKind frameKind, pybind11::object context,
111436c6c9cSStella Laurenzo                        pybind11::object insertionPoint,
112436c6c9cSStella Laurenzo                        pybind11::object location)
113436c6c9cSStella Laurenzo       : context(std::move(context)), insertionPoint(std::move(insertionPoint)),
114436c6c9cSStella Laurenzo         location(std::move(location)), frameKind(frameKind) {}
115436c6c9cSStella Laurenzo 
116436c6c9cSStella Laurenzo   /// Gets the top of stack context and return nullptr if not defined.
117436c6c9cSStella Laurenzo   static PyMlirContext *getDefaultContext();
118436c6c9cSStella Laurenzo 
119436c6c9cSStella Laurenzo   /// Gets the top of stack insertion point and return nullptr if not defined.
120436c6c9cSStella Laurenzo   static PyInsertionPoint *getDefaultInsertionPoint();
121436c6c9cSStella Laurenzo 
122436c6c9cSStella Laurenzo   /// Gets the top of stack location and returns nullptr if not defined.
123436c6c9cSStella Laurenzo   static PyLocation *getDefaultLocation();
124436c6c9cSStella Laurenzo 
125436c6c9cSStella Laurenzo   PyMlirContext *getContext();
126436c6c9cSStella Laurenzo   PyInsertionPoint *getInsertionPoint();
127436c6c9cSStella Laurenzo   PyLocation *getLocation();
getFrameKind()128436c6c9cSStella Laurenzo   FrameKind getFrameKind() { return frameKind; }
129436c6c9cSStella Laurenzo 
130436c6c9cSStella Laurenzo   /// Stack management.
131436c6c9cSStella Laurenzo   static PyThreadContextEntry *getTopOfStack();
132436c6c9cSStella Laurenzo   static pybind11::object pushContext(PyMlirContext &context);
133436c6c9cSStella Laurenzo   static void popContext(PyMlirContext &context);
134436c6c9cSStella Laurenzo   static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint);
135436c6c9cSStella Laurenzo   static void popInsertionPoint(PyInsertionPoint &insertionPoint);
136436c6c9cSStella Laurenzo   static pybind11::object pushLocation(PyLocation &location);
137436c6c9cSStella Laurenzo   static void popLocation(PyLocation &location);
138436c6c9cSStella Laurenzo 
139436c6c9cSStella Laurenzo   /// Gets the thread local stack.
140436c6c9cSStella Laurenzo   static std::vector<PyThreadContextEntry> &getStack();
141436c6c9cSStella Laurenzo 
142436c6c9cSStella Laurenzo private:
143436c6c9cSStella Laurenzo   static void push(FrameKind frameKind, pybind11::object context,
144436c6c9cSStella Laurenzo                    pybind11::object insertionPoint, pybind11::object location);
145436c6c9cSStella Laurenzo 
146436c6c9cSStella Laurenzo   /// An object reference to the PyContext.
147436c6c9cSStella Laurenzo   pybind11::object context;
148436c6c9cSStella Laurenzo   /// An object reference to the current insertion point.
149436c6c9cSStella Laurenzo   pybind11::object insertionPoint;
150436c6c9cSStella Laurenzo   /// An object reference to the current location.
151436c6c9cSStella Laurenzo   pybind11::object location;
152436c6c9cSStella Laurenzo   // The kind of push that was performed.
153436c6c9cSStella Laurenzo   FrameKind frameKind;
154436c6c9cSStella Laurenzo };
155436c6c9cSStella Laurenzo 
156436c6c9cSStella Laurenzo /// Wrapper around MlirContext.
157436c6c9cSStella Laurenzo using PyMlirContextRef = PyObjectRef<PyMlirContext>;
158436c6c9cSStella Laurenzo class PyMlirContext {
159436c6c9cSStella Laurenzo public:
160436c6c9cSStella Laurenzo   PyMlirContext() = delete;
161436c6c9cSStella Laurenzo   PyMlirContext(const PyMlirContext &) = delete;
162436c6c9cSStella Laurenzo   PyMlirContext(PyMlirContext &&) = delete;
163436c6c9cSStella Laurenzo 
164436c6c9cSStella Laurenzo   /// For the case of a python __init__ (py::init) method, pybind11 is quite
165436c6c9cSStella Laurenzo   /// strict about needing to return a pointer that is not yet associated to
166436c6c9cSStella Laurenzo   /// an py::object. Since the forContext() method acts like a pool, possibly
167436c6c9cSStella Laurenzo   /// returning a recycled context, it does not satisfy this need. The usual
168436c6c9cSStella Laurenzo   /// way in python to accomplish such a thing is to override __new__, but
169436c6c9cSStella Laurenzo   /// that is also not supported by pybind11. Instead, we use this entry
170436c6c9cSStella Laurenzo   /// point which always constructs a fresh context (which cannot alias an
171436c6c9cSStella Laurenzo   /// existing one because it is fresh).
172436c6c9cSStella Laurenzo   static PyMlirContext *createNewContextForInit();
173436c6c9cSStella Laurenzo 
174436c6c9cSStella Laurenzo   /// Returns a context reference for the singleton PyMlirContext wrapper for
175436c6c9cSStella Laurenzo   /// the given context.
176436c6c9cSStella Laurenzo   static PyMlirContextRef forContext(MlirContext context);
177436c6c9cSStella Laurenzo   ~PyMlirContext();
178436c6c9cSStella Laurenzo 
179436c6c9cSStella Laurenzo   /// Accesses the underlying MlirContext.
get()180436c6c9cSStella Laurenzo   MlirContext get() { return context; }
181436c6c9cSStella Laurenzo 
182436c6c9cSStella Laurenzo   /// Gets a strong reference to this context, which will ensure it is kept
183436c6c9cSStella Laurenzo   /// alive for the life of the reference.
getRef()184436c6c9cSStella Laurenzo   PyMlirContextRef getRef() {
185436c6c9cSStella Laurenzo     return PyMlirContextRef(this, pybind11::cast(this));
186436c6c9cSStella Laurenzo   }
187436c6c9cSStella Laurenzo 
188436c6c9cSStella Laurenzo   /// Gets a capsule wrapping the void* within the MlirContext.
189436c6c9cSStella Laurenzo   pybind11::object getCapsule();
190436c6c9cSStella Laurenzo 
191436c6c9cSStella Laurenzo   /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
192436c6c9cSStella Laurenzo   /// Note that PyMlirContext instances are uniqued, so the returned object
193436c6c9cSStella Laurenzo   /// may be a pre-existing object. Ownership of the underlying MlirContext
194436c6c9cSStella Laurenzo   /// is taken by calling this function.
195436c6c9cSStella Laurenzo   static pybind11::object createFromCapsule(pybind11::object capsule);
196436c6c9cSStella Laurenzo 
197436c6c9cSStella Laurenzo   /// Gets the count of live context objects. Used for testing.
198436c6c9cSStella Laurenzo   static size_t getLiveCount();
199436c6c9cSStella Laurenzo 
200436c6c9cSStella Laurenzo   /// Gets the count of live operations associated with this context.
201436c6c9cSStella Laurenzo   /// Used for testing.
202436c6c9cSStella Laurenzo   size_t getLiveOperationCount();
203436c6c9cSStella Laurenzo 
2046b0bed7eSJohn Demme   /// Clears the live operations map, returning the number of entries which were
2056b0bed7eSJohn Demme   /// invalidated. To be used as a safety mechanism so that API end-users can't
2066b0bed7eSJohn Demme   /// corrupt by holding references they shouldn't have accessed in the first
2076b0bed7eSJohn Demme   /// place.
2086b0bed7eSJohn Demme   size_t clearLiveOperations();
2096b0bed7eSJohn Demme 
210436c6c9cSStella Laurenzo   /// Gets the count of live modules associated with this context.
211436c6c9cSStella Laurenzo   /// Used for testing.
212436c6c9cSStella Laurenzo   size_t getLiveModuleCount();
213436c6c9cSStella Laurenzo 
214436c6c9cSStella Laurenzo   /// Enter and exit the context manager.
215436c6c9cSStella Laurenzo   pybind11::object contextEnter();
2161fc096afSMehdi Amini   void contextExit(const pybind11::object &excType,
2171fc096afSMehdi Amini                    const pybind11::object &excVal,
2181fc096afSMehdi Amini                    const pybind11::object &excTb);
219436c6c9cSStella Laurenzo 
2207ee25bc5SStella Laurenzo   /// Attaches a Python callback as a diagnostic handler, returning a
2217ee25bc5SStella Laurenzo   /// registration object (internally a PyDiagnosticHandler).
2227ee25bc5SStella Laurenzo   pybind11::object attachDiagnosticHandler(pybind11::object callback);
2237ee25bc5SStella Laurenzo 
224436c6c9cSStella Laurenzo private:
225436c6c9cSStella Laurenzo   PyMlirContext(MlirContext context);
226436c6c9cSStella Laurenzo   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
227436c6c9cSStella Laurenzo   // preserving the relationship that an MlirContext maps to a single
228436c6c9cSStella Laurenzo   // PyMlirContext wrapper. This could be replaced in the future with an
229436c6c9cSStella Laurenzo   // extension mechanism on the MlirContext for stashing user pointers.
230436c6c9cSStella Laurenzo   // Note that this holds a handle, which does not imply ownership.
231436c6c9cSStella Laurenzo   // Mappings will be removed when the context is destructed.
232436c6c9cSStella Laurenzo   using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
233436c6c9cSStella Laurenzo   static LiveContextMap &getLiveContexts();
234436c6c9cSStella Laurenzo 
235436c6c9cSStella Laurenzo   // Interns all live modules associated with this context. Modules tracked
236436c6c9cSStella Laurenzo   // in this map are valid. When a module is invalidated, it is removed
237436c6c9cSStella Laurenzo   // from this map, and while it still exists as an instance, any
238436c6c9cSStella Laurenzo   // attempt to access it will raise an error.
239436c6c9cSStella Laurenzo   using LiveModuleMap =
240436c6c9cSStella Laurenzo       llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
241436c6c9cSStella Laurenzo   LiveModuleMap liveModules;
242436c6c9cSStella Laurenzo 
243436c6c9cSStella Laurenzo   // Interns all live operations associated with this context. Operations
244436c6c9cSStella Laurenzo   // tracked in this map are valid. When an operation is invalidated, it is
245436c6c9cSStella Laurenzo   // removed from this map, and while it still exists as an instance, any
246436c6c9cSStella Laurenzo   // attempt to access it will raise an error.
247436c6c9cSStella Laurenzo   using LiveOperationMap =
248436c6c9cSStella Laurenzo       llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
249436c6c9cSStella Laurenzo   LiveOperationMap liveOperations;
250436c6c9cSStella Laurenzo 
251436c6c9cSStella Laurenzo   MlirContext context;
252436c6c9cSStella Laurenzo   friend class PyModule;
253436c6c9cSStella Laurenzo   friend class PyOperation;
254436c6c9cSStella Laurenzo };
255436c6c9cSStella Laurenzo 
256436c6c9cSStella Laurenzo /// Used in function arguments when None should resolve to the current context
257436c6c9cSStella Laurenzo /// manager set instance.
258436c6c9cSStella Laurenzo class DefaultingPyMlirContext
259436c6c9cSStella Laurenzo     : public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
260436c6c9cSStella Laurenzo public:
261436c6c9cSStella Laurenzo   using Defaulting::Defaulting;
262a6e7d024SStella Laurenzo   static constexpr const char kTypeDescription[] = "mlir.ir.Context";
263436c6c9cSStella Laurenzo   static PyMlirContext &resolve();
264436c6c9cSStella Laurenzo };
265436c6c9cSStella Laurenzo 
266436c6c9cSStella Laurenzo /// Base class for all objects that directly or indirectly depend on an
267436c6c9cSStella Laurenzo /// MlirContext. The lifetime of the context will extend at least to the
268436c6c9cSStella Laurenzo /// lifetime of these instances.
269436c6c9cSStella Laurenzo /// Immutable objects that depend on a context extend this directly.
270436c6c9cSStella Laurenzo class BaseContextObject {
271436c6c9cSStella Laurenzo public:
BaseContextObject(PyMlirContextRef ref)272436c6c9cSStella Laurenzo   BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
273436c6c9cSStella Laurenzo     assert(this->contextRef &&
274436c6c9cSStella Laurenzo            "context object constructed with null context ref");
275436c6c9cSStella Laurenzo   }
276436c6c9cSStella Laurenzo 
277436c6c9cSStella Laurenzo   /// Accesses the context reference.
getContext()278436c6c9cSStella Laurenzo   PyMlirContextRef &getContext() { return contextRef; }
279436c6c9cSStella Laurenzo 
280436c6c9cSStella Laurenzo private:
281436c6c9cSStella Laurenzo   PyMlirContextRef contextRef;
282436c6c9cSStella Laurenzo };
283436c6c9cSStella Laurenzo 
2847ee25bc5SStella Laurenzo /// Python class mirroring the C MlirDiagnostic struct. Note that these structs
2857ee25bc5SStella Laurenzo /// are only valid for the duration of a diagnostic callback and attempting
2867ee25bc5SStella Laurenzo /// to access them outside of that will raise an exception. This applies to
2877ee25bc5SStella Laurenzo /// nested diagnostics (in the notes) as well.
2887ee25bc5SStella Laurenzo class PyDiagnostic {
2897ee25bc5SStella Laurenzo public:
PyDiagnostic(MlirDiagnostic diagnostic)2907ee25bc5SStella Laurenzo   PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}
2917ee25bc5SStella Laurenzo   void invalidate();
isValid()2927ee25bc5SStella Laurenzo   bool isValid() { return valid; }
2937ee25bc5SStella Laurenzo   MlirDiagnosticSeverity getSeverity();
2947ee25bc5SStella Laurenzo   PyLocation getLocation();
2957ee25bc5SStella Laurenzo   pybind11::str getMessage();
2967ee25bc5SStella Laurenzo   pybind11::tuple getNotes();
2977ee25bc5SStella Laurenzo 
2987ee25bc5SStella Laurenzo private:
2997ee25bc5SStella Laurenzo   MlirDiagnostic diagnostic;
3007ee25bc5SStella Laurenzo 
3017ee25bc5SStella Laurenzo   void checkValid();
3027ee25bc5SStella Laurenzo   /// If notes have been materialized from the diagnostic, then this will
3037ee25bc5SStella Laurenzo   /// be populated with the corresponding objects (all castable to
3047ee25bc5SStella Laurenzo   /// PyDiagnostic).
3057ee25bc5SStella Laurenzo   llvm::Optional<pybind11::tuple> materializedNotes;
3067ee25bc5SStella Laurenzo   bool valid = true;
3077ee25bc5SStella Laurenzo };
3087ee25bc5SStella Laurenzo 
3097ee25bc5SStella Laurenzo /// Represents a diagnostic handler attached to the context. The handler's
3107ee25bc5SStella Laurenzo /// callback will be invoked with PyDiagnostic instances until the detach()
3117ee25bc5SStella Laurenzo /// method is called or the context is destroyed. A diagnostic handler can be
3127ee25bc5SStella Laurenzo /// the subject of a `with` block, which will detach it when the block exits.
3137ee25bc5SStella Laurenzo ///
3147ee25bc5SStella Laurenzo /// Since diagnostic handlers can call back into Python code which can do
3157ee25bc5SStella Laurenzo /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions,
3167ee25bc5SStella Laurenzo /// etc), this is generally not deemed to be a great user-level API. Users
3177ee25bc5SStella Laurenzo /// should generally use some form of DiagnosticCollector. If the handler raises
3187ee25bc5SStella Laurenzo /// any exceptions, they will just be emitted to stderr and dropped.
3197ee25bc5SStella Laurenzo ///
3207ee25bc5SStella Laurenzo /// The unique usage of this class means that its lifetime management is
3217ee25bc5SStella Laurenzo /// different from most other parts of the API. Instances are always created
3227ee25bc5SStella Laurenzo /// in an attached state and can transition to a detached state by either:
3237ee25bc5SStella Laurenzo ///   a) The context being destroyed and unregistering all handlers.
3247ee25bc5SStella Laurenzo ///   b) An explicit call to detach().
3257ee25bc5SStella Laurenzo /// The object may remain live from a Python perspective for an arbitrary time
3267ee25bc5SStella Laurenzo /// after detachment, but there is nothing the user can do with it (since there
3277ee25bc5SStella Laurenzo /// is no way to attach an existing handler object).
3287ee25bc5SStella Laurenzo class PyDiagnosticHandler {
3297ee25bc5SStella Laurenzo public:
3307ee25bc5SStella Laurenzo   PyDiagnosticHandler(MlirContext context, pybind11::object callback);
3317ee25bc5SStella Laurenzo   ~PyDiagnosticHandler();
3327ee25bc5SStella Laurenzo 
isAttached()3337ee25bc5SStella Laurenzo   bool isAttached() { return registeredID.hasValue(); }
getHadError()3347ee25bc5SStella Laurenzo   bool getHadError() { return hadError; }
3357ee25bc5SStella Laurenzo 
3367ee25bc5SStella Laurenzo   /// Detaches the handler. Does nothing if not attached.
3377ee25bc5SStella Laurenzo   void detach();
3387ee25bc5SStella Laurenzo 
contextEnter()3397ee25bc5SStella Laurenzo   pybind11::object contextEnter() { return pybind11::cast(this); }
contextExit(const pybind11::object & excType,const pybind11::object & excVal,const pybind11::object & excTb)340e8d07395SMehdi Amini   void contextExit(const pybind11::object &excType,
341e8d07395SMehdi Amini                    const pybind11::object &excVal,
342e8d07395SMehdi Amini                    const pybind11::object &excTb) {
3437ee25bc5SStella Laurenzo     detach();
3447ee25bc5SStella Laurenzo   }
3457ee25bc5SStella Laurenzo 
3467ee25bc5SStella Laurenzo private:
3477ee25bc5SStella Laurenzo   MlirContext context;
3487ee25bc5SStella Laurenzo   pybind11::object callback;
3497ee25bc5SStella Laurenzo   llvm::Optional<MlirDiagnosticHandlerID> registeredID;
3507ee25bc5SStella Laurenzo   bool hadError = false;
3517ee25bc5SStella Laurenzo   friend class PyMlirContext;
3527ee25bc5SStella Laurenzo };
3537ee25bc5SStella Laurenzo 
354436c6c9cSStella Laurenzo /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
355436c6c9cSStella Laurenzo /// order to differentiate it from the `Dialect` base class which is extended by
356436c6c9cSStella Laurenzo /// plugins which extend dialect functionality through extension python code.
357436c6c9cSStella Laurenzo /// This should be seen as the "low-level" object and `Dialect` as the
358436c6c9cSStella Laurenzo /// high-level, user facing object.
359436c6c9cSStella Laurenzo class PyDialectDescriptor : public BaseContextObject {
360436c6c9cSStella Laurenzo public:
PyDialectDescriptor(PyMlirContextRef contextRef,MlirDialect dialect)361436c6c9cSStella Laurenzo   PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
362436c6c9cSStella Laurenzo       : BaseContextObject(std::move(contextRef)), dialect(dialect) {}
363436c6c9cSStella Laurenzo 
get()364436c6c9cSStella Laurenzo   MlirDialect get() { return dialect; }
365436c6c9cSStella Laurenzo 
366436c6c9cSStella Laurenzo private:
367436c6c9cSStella Laurenzo   MlirDialect dialect;
368436c6c9cSStella Laurenzo };
369436c6c9cSStella Laurenzo 
370436c6c9cSStella Laurenzo /// User-level object for accessing dialects with dotted syntax such as:
371436c6c9cSStella Laurenzo ///   ctx.dialect.std
372436c6c9cSStella Laurenzo class PyDialects : public BaseContextObject {
373436c6c9cSStella Laurenzo public:
PyDialects(PyMlirContextRef contextRef)374436c6c9cSStella Laurenzo   PyDialects(PyMlirContextRef contextRef)
375436c6c9cSStella Laurenzo       : BaseContextObject(std::move(contextRef)) {}
376436c6c9cSStella Laurenzo 
377436c6c9cSStella Laurenzo   MlirDialect getDialectForKey(const std::string &key, bool attrError);
378436c6c9cSStella Laurenzo };
379436c6c9cSStella Laurenzo 
380436c6c9cSStella Laurenzo /// User-level dialect object. For dialects that have a registered extension,
381436c6c9cSStella Laurenzo /// this will be the base class of the extension dialect type. For un-extended,
382436c6c9cSStella Laurenzo /// objects of this type will be returned directly.
383436c6c9cSStella Laurenzo class PyDialect {
384436c6c9cSStella Laurenzo public:
PyDialect(pybind11::object descriptor)385436c6c9cSStella Laurenzo   PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}
386436c6c9cSStella Laurenzo 
getDescriptor()387436c6c9cSStella Laurenzo   pybind11::object getDescriptor() { return descriptor; }
388436c6c9cSStella Laurenzo 
389436c6c9cSStella Laurenzo private:
390436c6c9cSStella Laurenzo   pybind11::object descriptor;
391436c6c9cSStella Laurenzo };
392436c6c9cSStella Laurenzo 
393*5e83a5b4SStella Laurenzo /// Wrapper around an MlirDialectRegistry.
394*5e83a5b4SStella Laurenzo /// Upon construction, the Python wrapper takes ownership of the
395*5e83a5b4SStella Laurenzo /// underlying MlirDialectRegistry.
396*5e83a5b4SStella Laurenzo class PyDialectRegistry {
397*5e83a5b4SStella Laurenzo public:
PyDialectRegistry()398*5e83a5b4SStella Laurenzo   PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {}
PyDialectRegistry(MlirDialectRegistry registry)399*5e83a5b4SStella Laurenzo   PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {}
~PyDialectRegistry()400*5e83a5b4SStella Laurenzo   ~PyDialectRegistry() {
401*5e83a5b4SStella Laurenzo     if (!mlirDialectRegistryIsNull(registry))
402*5e83a5b4SStella Laurenzo       mlirDialectRegistryDestroy(registry);
403*5e83a5b4SStella Laurenzo   }
404*5e83a5b4SStella Laurenzo   PyDialectRegistry(PyDialectRegistry &) = delete;
PyDialectRegistry(PyDialectRegistry && other)405*5e83a5b4SStella Laurenzo   PyDialectRegistry(PyDialectRegistry &&other) : registry(other.registry) {
406*5e83a5b4SStella Laurenzo     other.registry = {nullptr};
407*5e83a5b4SStella Laurenzo   }
408*5e83a5b4SStella Laurenzo 
MlirDialectRegistry()409*5e83a5b4SStella Laurenzo   operator MlirDialectRegistry() const { return registry; }
get()410*5e83a5b4SStella Laurenzo   MlirDialectRegistry get() const { return registry; }
411*5e83a5b4SStella Laurenzo 
412*5e83a5b4SStella Laurenzo   pybind11::object getCapsule();
413*5e83a5b4SStella Laurenzo   static PyDialectRegistry createFromCapsule(pybind11::object capsule);
414*5e83a5b4SStella Laurenzo 
415*5e83a5b4SStella Laurenzo private:
416*5e83a5b4SStella Laurenzo   MlirDialectRegistry registry;
417*5e83a5b4SStella Laurenzo };
418*5e83a5b4SStella Laurenzo 
419436c6c9cSStella Laurenzo /// Wrapper around an MlirLocation.
420436c6c9cSStella Laurenzo class PyLocation : public BaseContextObject {
421436c6c9cSStella Laurenzo public:
PyLocation(PyMlirContextRef contextRef,MlirLocation loc)422436c6c9cSStella Laurenzo   PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
423436c6c9cSStella Laurenzo       : BaseContextObject(std::move(contextRef)), loc(loc) {}
424436c6c9cSStella Laurenzo 
MlirLocation()425436c6c9cSStella Laurenzo   operator MlirLocation() const { return loc; }
get()426436c6c9cSStella Laurenzo   MlirLocation get() const { return loc; }
427436c6c9cSStella Laurenzo 
428436c6c9cSStella Laurenzo   /// Enter and exit the context manager.
429436c6c9cSStella Laurenzo   pybind11::object contextEnter();
4301fc096afSMehdi Amini   void contextExit(const pybind11::object &excType,
4311fc096afSMehdi Amini                    const pybind11::object &excVal,
4321fc096afSMehdi Amini                    const pybind11::object &excTb);
433436c6c9cSStella Laurenzo 
434436c6c9cSStella Laurenzo   /// Gets a capsule wrapping the void* within the MlirLocation.
435436c6c9cSStella Laurenzo   pybind11::object getCapsule();
436436c6c9cSStella Laurenzo 
437436c6c9cSStella Laurenzo   /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
438436c6c9cSStella Laurenzo   /// Note that PyLocation instances are uniqued, so the returned object
439436c6c9cSStella Laurenzo   /// may be a pre-existing object. Ownership of the underlying MlirLocation
440436c6c9cSStella Laurenzo   /// is taken by calling this function.
441436c6c9cSStella Laurenzo   static PyLocation createFromCapsule(pybind11::object capsule);
442436c6c9cSStella Laurenzo 
443436c6c9cSStella Laurenzo private:
444436c6c9cSStella Laurenzo   MlirLocation loc;
445436c6c9cSStella Laurenzo };
446436c6c9cSStella Laurenzo 
447436c6c9cSStella Laurenzo /// Used in function arguments when None should resolve to the current context
448436c6c9cSStella Laurenzo /// manager set instance.
449436c6c9cSStella Laurenzo class DefaultingPyLocation
450436c6c9cSStella Laurenzo     : public Defaulting<DefaultingPyLocation, PyLocation> {
451436c6c9cSStella Laurenzo public:
452436c6c9cSStella Laurenzo   using Defaulting::Defaulting;
453a6e7d024SStella Laurenzo   static constexpr const char kTypeDescription[] = "mlir.ir.Location";
454436c6c9cSStella Laurenzo   static PyLocation &resolve();
455436c6c9cSStella Laurenzo 
MlirLocation()456436c6c9cSStella Laurenzo   operator MlirLocation() const { return *get(); }
457436c6c9cSStella Laurenzo };
458436c6c9cSStella Laurenzo 
459436c6c9cSStella Laurenzo /// Wrapper around MlirModule.
460436c6c9cSStella Laurenzo /// This is the top-level, user-owned object that contains regions/ops/blocks.
461436c6c9cSStella Laurenzo class PyModule;
462436c6c9cSStella Laurenzo using PyModuleRef = PyObjectRef<PyModule>;
463436c6c9cSStella Laurenzo class PyModule : public BaseContextObject {
464436c6c9cSStella Laurenzo public:
465436c6c9cSStella Laurenzo   /// Returns a PyModule reference for the given MlirModule. This may return
466436c6c9cSStella Laurenzo   /// a pre-existing or new object.
467436c6c9cSStella Laurenzo   static PyModuleRef forModule(MlirModule module);
468436c6c9cSStella Laurenzo   PyModule(PyModule &) = delete;
469436c6c9cSStella Laurenzo   PyModule(PyMlirContext &&) = delete;
470436c6c9cSStella Laurenzo   ~PyModule();
471436c6c9cSStella Laurenzo 
472436c6c9cSStella Laurenzo   /// Gets the backing MlirModule.
get()473436c6c9cSStella Laurenzo   MlirModule get() { return module; }
474436c6c9cSStella Laurenzo 
475436c6c9cSStella Laurenzo   /// Gets a strong reference to this module.
getRef()476436c6c9cSStella Laurenzo   PyModuleRef getRef() {
477436c6c9cSStella Laurenzo     return PyModuleRef(this,
478436c6c9cSStella Laurenzo                        pybind11::reinterpret_borrow<pybind11::object>(handle));
479436c6c9cSStella Laurenzo   }
480436c6c9cSStella Laurenzo 
481436c6c9cSStella Laurenzo   /// Gets a capsule wrapping the void* within the MlirModule.
482436c6c9cSStella Laurenzo   /// Note that the module does not (yet) provide a corresponding factory for
483436c6c9cSStella Laurenzo   /// constructing from a capsule as that would require uniquing PyModule
484436c6c9cSStella Laurenzo   /// instances, which is not currently done.
485436c6c9cSStella Laurenzo   pybind11::object getCapsule();
486436c6c9cSStella Laurenzo 
487436c6c9cSStella Laurenzo   /// Creates a PyModule from the MlirModule wrapped by a capsule.
488436c6c9cSStella Laurenzo   /// Note that PyModule instances are uniqued, so the returned object
489436c6c9cSStella Laurenzo   /// may be a pre-existing object. Ownership of the underlying MlirModule
490436c6c9cSStella Laurenzo   /// is taken by calling this function.
491436c6c9cSStella Laurenzo   static pybind11::object createFromCapsule(pybind11::object capsule);
492436c6c9cSStella Laurenzo 
493436c6c9cSStella Laurenzo private:
494436c6c9cSStella Laurenzo   PyModule(PyMlirContextRef contextRef, MlirModule module);
495436c6c9cSStella Laurenzo   MlirModule module;
496436c6c9cSStella Laurenzo   pybind11::handle handle;
497436c6c9cSStella Laurenzo };
498436c6c9cSStella Laurenzo 
499436c6c9cSStella Laurenzo /// Base class for PyOperation and PyOpView which exposes the primary, user
500436c6c9cSStella Laurenzo /// visible methods for manipulating it.
501436c6c9cSStella Laurenzo class PyOperationBase {
502436c6c9cSStella Laurenzo public:
503436c6c9cSStella Laurenzo   virtual ~PyOperationBase() = default;
504436c6c9cSStella Laurenzo   /// Implements the bound 'print' method and helps with others.
505436c6c9cSStella Laurenzo   void print(pybind11::object fileObject, bool binary,
506436c6c9cSStella Laurenzo              llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
507ace1d0adSStella Laurenzo              bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
508ace1d0adSStella Laurenzo              bool assumeVerified);
509436c6c9cSStella Laurenzo   pybind11::object getAsm(bool binary,
510436c6c9cSStella Laurenzo                           llvm::Optional<int64_t> largeElementsLimit,
511436c6c9cSStella Laurenzo                           bool enableDebugInfo, bool prettyDebugInfo,
512ace1d0adSStella Laurenzo                           bool printGenericOpForm, bool useLocalScope,
513ace1d0adSStella Laurenzo                           bool assumeVerified);
514436c6c9cSStella Laurenzo 
51524685aaeSAlex Zinenko   /// Moves the operation before or after the other operation.
51624685aaeSAlex Zinenko   void moveAfter(PyOperationBase &other);
51724685aaeSAlex Zinenko   void moveBefore(PyOperationBase &other);
51824685aaeSAlex Zinenko 
519436c6c9cSStella Laurenzo   /// Each must provide access to the raw Operation.
520436c6c9cSStella Laurenzo   virtual PyOperation &getOperation() = 0;
521436c6c9cSStella Laurenzo };
522436c6c9cSStella Laurenzo 
523436c6c9cSStella Laurenzo /// Wrapper around PyOperation.
524436c6c9cSStella Laurenzo /// Operations exist in either an attached (dependent) or detached (top-level)
525436c6c9cSStella Laurenzo /// state. In the detached state (as on creation), an operation is owned by
526436c6c9cSStella Laurenzo /// the creator and its lifetime extends either until its reference count
527436c6c9cSStella Laurenzo /// drops to zero or it is attached to a parent, at which point its lifetime
528436c6c9cSStella Laurenzo /// is bounded by its top-level parent reference.
529436c6c9cSStella Laurenzo class PyOperation;
530436c6c9cSStella Laurenzo using PyOperationRef = PyObjectRef<PyOperation>;
531436c6c9cSStella Laurenzo class PyOperation : public PyOperationBase, public BaseContextObject {
532436c6c9cSStella Laurenzo public:
533bd87241cSMehdi Amini   ~PyOperation() override;
getOperation()534436c6c9cSStella Laurenzo   PyOperation &getOperation() override { return *this; }
535436c6c9cSStella Laurenzo 
536436c6c9cSStella Laurenzo   /// Returns a PyOperation for the given MlirOperation, optionally associating
537436c6c9cSStella Laurenzo   /// it with a parentKeepAlive.
538436c6c9cSStella Laurenzo   static PyOperationRef
539436c6c9cSStella Laurenzo   forOperation(PyMlirContextRef contextRef, MlirOperation operation,
540436c6c9cSStella Laurenzo                pybind11::object parentKeepAlive = pybind11::object());
541436c6c9cSStella Laurenzo 
542436c6c9cSStella Laurenzo   /// Creates a detached operation. The operation must not be associated with
543436c6c9cSStella Laurenzo   /// any existing live operation.
544436c6c9cSStella Laurenzo   static PyOperationRef
545436c6c9cSStella Laurenzo   createDetached(PyMlirContextRef contextRef, MlirOperation operation,
546436c6c9cSStella Laurenzo                  pybind11::object parentKeepAlive = pybind11::object());
547436c6c9cSStella Laurenzo 
54824685aaeSAlex Zinenko   /// Detaches the operation from its parent block and updates its state
54924685aaeSAlex Zinenko   /// accordingly.
detachFromParent()55024685aaeSAlex Zinenko   void detachFromParent() {
55124685aaeSAlex Zinenko     mlirOperationRemoveFromParent(getOperation());
55224685aaeSAlex Zinenko     setDetached();
55324685aaeSAlex Zinenko     parentKeepAlive = pybind11::object();
55424685aaeSAlex Zinenko   }
55524685aaeSAlex Zinenko 
556436c6c9cSStella Laurenzo   /// Gets the backing operation.
MlirOperation()557436c6c9cSStella Laurenzo   operator MlirOperation() const { return get(); }
get()558436c6c9cSStella Laurenzo   MlirOperation get() const {
559436c6c9cSStella Laurenzo     checkValid();
560436c6c9cSStella Laurenzo     return operation;
561436c6c9cSStella Laurenzo   }
562436c6c9cSStella Laurenzo 
getRef()563436c6c9cSStella Laurenzo   PyOperationRef getRef() {
564436c6c9cSStella Laurenzo     return PyOperationRef(
565436c6c9cSStella Laurenzo         this, pybind11::reinterpret_borrow<pybind11::object>(handle));
566436c6c9cSStella Laurenzo   }
567436c6c9cSStella Laurenzo 
isAttached()568436c6c9cSStella Laurenzo   bool isAttached() { return attached; }
569e8d07395SMehdi Amini   void setAttached(const pybind11::object &parent = pybind11::object()) {
570436c6c9cSStella Laurenzo     assert(!attached && "operation already attached");
571436c6c9cSStella Laurenzo     attached = true;
572436c6c9cSStella Laurenzo   }
setDetached()57324685aaeSAlex Zinenko   void setDetached() {
57424685aaeSAlex Zinenko     assert(attached && "operation already detached");
57524685aaeSAlex Zinenko     attached = false;
57624685aaeSAlex Zinenko   }
577436c6c9cSStella Laurenzo   void checkValid() const;
578436c6c9cSStella Laurenzo 
579436c6c9cSStella Laurenzo   /// Gets the owning block or raises an exception if the operation has no
580436c6c9cSStella Laurenzo   /// owning block.
581436c6c9cSStella Laurenzo   PyBlock getBlock();
582436c6c9cSStella Laurenzo 
583436c6c9cSStella Laurenzo   /// Gets the parent operation or raises an exception if the operation has
584436c6c9cSStella Laurenzo   /// no parent.
5851689dadeSJohn Demme   llvm::Optional<PyOperationRef> getParentOperation();
586436c6c9cSStella Laurenzo 
5870126e906SJohn Demme   /// Gets a capsule wrapping the void* within the MlirOperation.
5880126e906SJohn Demme   pybind11::object getCapsule();
5890126e906SJohn Demme 
5900126e906SJohn Demme   /// Creates a PyOperation from the MlirOperation wrapped by a capsule.
5910126e906SJohn Demme   /// Ownership of the underlying MlirOperation is taken by calling this
5920126e906SJohn Demme   /// function.
5930126e906SJohn Demme   static pybind11::object createFromCapsule(pybind11::object capsule);
5940126e906SJohn Demme 
595436c6c9cSStella Laurenzo   /// Creates an operation. See corresponding python docstring.
596436c6c9cSStella Laurenzo   static pybind11::object
5971fc096afSMehdi Amini   create(const std::string &name, llvm::Optional<std::vector<PyType *>> results,
598436c6c9cSStella Laurenzo          llvm::Optional<std::vector<PyValue *>> operands,
599436c6c9cSStella Laurenzo          llvm::Optional<pybind11::dict> attributes,
600436c6c9cSStella Laurenzo          llvm::Optional<std::vector<PyBlock *>> successors, int regions,
6011fc096afSMehdi Amini          DefaultingPyLocation location, const pybind11::object &ip);
602436c6c9cSStella Laurenzo 
603436c6c9cSStella Laurenzo   /// Creates an OpView suitable for this operation.
604436c6c9cSStella Laurenzo   pybind11::object createOpView();
605436c6c9cSStella Laurenzo 
60649745f87SMike Urbach   /// Erases the underlying MlirOperation, removes its pointer from the
60749745f87SMike Urbach   /// parent context's live operations map, and sets the valid bit false.
60849745f87SMike Urbach   void erase();
60949745f87SMike Urbach 
6106b0bed7eSJohn Demme   /// Invalidate the operation.
setInvalid()6116b0bed7eSJohn Demme   void setInvalid() { valid = false; }
6126b0bed7eSJohn Demme 
613774818c0SDominik Grewe   /// Clones this operation.
614774818c0SDominik Grewe   pybind11::object clone(const pybind11::object &ip);
615774818c0SDominik Grewe 
616436c6c9cSStella Laurenzo private:
617436c6c9cSStella Laurenzo   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
618436c6c9cSStella Laurenzo   static PyOperationRef createInstance(PyMlirContextRef contextRef,
619436c6c9cSStella Laurenzo                                        MlirOperation operation,
620436c6c9cSStella Laurenzo                                        pybind11::object parentKeepAlive);
621436c6c9cSStella Laurenzo 
622436c6c9cSStella Laurenzo   MlirOperation operation;
623436c6c9cSStella Laurenzo   pybind11::handle handle;
624436c6c9cSStella Laurenzo   // Keeps the parent alive, regardless of whether it is an Operation or
625436c6c9cSStella Laurenzo   // Module.
626436c6c9cSStella Laurenzo   // TODO: As implemented, this facility is only sufficient for modeling the
627436c6c9cSStella Laurenzo   // trivial module parent back-reference. Generalize this to also account for
628436c6c9cSStella Laurenzo   // transitions from detached to attached and address TODOs in the
629436c6c9cSStella Laurenzo   // ir_operation.py regarding testing corresponding lifetime guarantees.
630436c6c9cSStella Laurenzo   pybind11::object parentKeepAlive;
631436c6c9cSStella Laurenzo   bool attached = true;
632436c6c9cSStella Laurenzo   bool valid = true;
63324685aaeSAlex Zinenko 
63424685aaeSAlex Zinenko   friend class PyOperationBase;
63530d61893SAlex Zinenko   friend class PySymbolTable;
636436c6c9cSStella Laurenzo };
637436c6c9cSStella Laurenzo 
638436c6c9cSStella Laurenzo /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
639436c6c9cSStella Laurenzo /// providing more instance-specific accessors and serve as the base class for
640436c6c9cSStella Laurenzo /// custom ODS-style operation classes. Since this class is subclass on the
641436c6c9cSStella Laurenzo /// python side, it must present an __init__ method that operates in pure
642436c6c9cSStella Laurenzo /// python types.
643436c6c9cSStella Laurenzo class PyOpView : public PyOperationBase {
644436c6c9cSStella Laurenzo public:
6451fc096afSMehdi Amini   PyOpView(const pybind11::object &operationObject);
getOperation()646436c6c9cSStella Laurenzo   PyOperation &getOperation() override { return operation; }
647436c6c9cSStella Laurenzo 
6481fc096afSMehdi Amini   static pybind11::object createRawSubclass(const pybind11::object &userClass);
649436c6c9cSStella Laurenzo 
getOperationObject()650436c6c9cSStella Laurenzo   pybind11::object getOperationObject() { return operationObject; }
651436c6c9cSStella Laurenzo 
652436c6c9cSStella Laurenzo   static pybind11::object
6531fc096afSMehdi Amini   buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList,
654436c6c9cSStella Laurenzo                pybind11::list operandList,
655436c6c9cSStella Laurenzo                llvm::Optional<pybind11::dict> attributes,
656436c6c9cSStella Laurenzo                llvm::Optional<std::vector<PyBlock *>> successors,
657436c6c9cSStella Laurenzo                llvm::Optional<int> regions, DefaultingPyLocation location,
6584f415216SMehdi Amini                const pybind11::object &maybeIp);
659436c6c9cSStella Laurenzo 
660436c6c9cSStella Laurenzo private:
661436c6c9cSStella Laurenzo   PyOperation &operation;           // For efficient, cast-free access from C++
662436c6c9cSStella Laurenzo   pybind11::object operationObject; // Holds the reference.
663436c6c9cSStella Laurenzo };
664436c6c9cSStella Laurenzo 
665436c6c9cSStella Laurenzo /// Wrapper around an MlirRegion.
666436c6c9cSStella Laurenzo /// Regions are managed completely by their containing operation. Unlike the
667436c6c9cSStella Laurenzo /// C++ API, the python API does not support detached regions.
668436c6c9cSStella Laurenzo class PyRegion {
669436c6c9cSStella Laurenzo public:
PyRegion(PyOperationRef parentOperation,MlirRegion region)670436c6c9cSStella Laurenzo   PyRegion(PyOperationRef parentOperation, MlirRegion region)
671436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), region(region) {
672436c6c9cSStella Laurenzo     assert(!mlirRegionIsNull(region) && "python region cannot be null");
673436c6c9cSStella Laurenzo   }
MlirRegion()67478f2dae0SAlex Zinenko   operator MlirRegion() const { return region; }
675436c6c9cSStella Laurenzo 
get()676436c6c9cSStella Laurenzo   MlirRegion get() { return region; }
getParentOperation()677436c6c9cSStella Laurenzo   PyOperationRef &getParentOperation() { return parentOperation; }
678436c6c9cSStella Laurenzo 
checkValid()679436c6c9cSStella Laurenzo   void checkValid() { return parentOperation->checkValid(); }
680436c6c9cSStella Laurenzo 
681436c6c9cSStella Laurenzo private:
682436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
683436c6c9cSStella Laurenzo   MlirRegion region;
684436c6c9cSStella Laurenzo };
685436c6c9cSStella Laurenzo 
686436c6c9cSStella Laurenzo /// Wrapper around an MlirBlock.
687436c6c9cSStella Laurenzo /// Blocks are managed completely by their containing operation. Unlike the
688436c6c9cSStella Laurenzo /// C++ API, the python API does not support detached blocks.
689436c6c9cSStella Laurenzo class PyBlock {
690436c6c9cSStella Laurenzo public:
PyBlock(PyOperationRef parentOperation,MlirBlock block)691436c6c9cSStella Laurenzo   PyBlock(PyOperationRef parentOperation, MlirBlock block)
692436c6c9cSStella Laurenzo       : parentOperation(std::move(parentOperation)), block(block) {
693436c6c9cSStella Laurenzo     assert(!mlirBlockIsNull(block) && "python block cannot be null");
694436c6c9cSStella Laurenzo   }
695436c6c9cSStella Laurenzo 
get()696436c6c9cSStella Laurenzo   MlirBlock get() { return block; }
getParentOperation()697436c6c9cSStella Laurenzo   PyOperationRef &getParentOperation() { return parentOperation; }
698436c6c9cSStella Laurenzo 
checkValid()699436c6c9cSStella Laurenzo   void checkValid() { return parentOperation->checkValid(); }
700436c6c9cSStella Laurenzo 
701436c6c9cSStella Laurenzo private:
702436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
703436c6c9cSStella Laurenzo   MlirBlock block;
704436c6c9cSStella Laurenzo };
705436c6c9cSStella Laurenzo 
706436c6c9cSStella Laurenzo /// An insertion point maintains a pointer to a Block and a reference operation.
707436c6c9cSStella Laurenzo /// Calls to insert() will insert a new operation before the
708436c6c9cSStella Laurenzo /// reference operation. If the reference operation is null, then appends to
709436c6c9cSStella Laurenzo /// the end of the block.
710436c6c9cSStella Laurenzo class PyInsertionPoint {
711436c6c9cSStella Laurenzo public:
712436c6c9cSStella Laurenzo   /// Creates an insertion point positioned after the last operation in the
713436c6c9cSStella Laurenzo   /// block, but still inside the block.
714436c6c9cSStella Laurenzo   PyInsertionPoint(PyBlock &block);
715436c6c9cSStella Laurenzo   /// Creates an insertion point positioned before a reference operation.
716436c6c9cSStella Laurenzo   PyInsertionPoint(PyOperationBase &beforeOperationBase);
717436c6c9cSStella Laurenzo 
718436c6c9cSStella Laurenzo   /// Shortcut to create an insertion point at the beginning of the block.
719436c6c9cSStella Laurenzo   static PyInsertionPoint atBlockBegin(PyBlock &block);
720436c6c9cSStella Laurenzo   /// Shortcut to create an insertion point before the block terminator.
721436c6c9cSStella Laurenzo   static PyInsertionPoint atBlockTerminator(PyBlock &block);
722436c6c9cSStella Laurenzo 
723436c6c9cSStella Laurenzo   /// Inserts an operation.
724436c6c9cSStella Laurenzo   void insert(PyOperationBase &operationBase);
725436c6c9cSStella Laurenzo 
726436c6c9cSStella Laurenzo   /// Enter and exit the context manager.
727436c6c9cSStella Laurenzo   pybind11::object contextEnter();
7281fc096afSMehdi Amini   void contextExit(const pybind11::object &excType,
7291fc096afSMehdi Amini                    const pybind11::object &excVal,
7301fc096afSMehdi Amini                    const pybind11::object &excTb);
731436c6c9cSStella Laurenzo 
getBlock()732436c6c9cSStella Laurenzo   PyBlock &getBlock() { return block; }
733436c6c9cSStella Laurenzo 
734436c6c9cSStella Laurenzo private:
735436c6c9cSStella Laurenzo   // Trampoline constructor that avoids null initializing members while
736436c6c9cSStella Laurenzo   // looking up parents.
PyInsertionPoint(PyBlock block,llvm::Optional<PyOperationRef> refOperation)737436c6c9cSStella Laurenzo   PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation)
738436c6c9cSStella Laurenzo       : refOperation(std::move(refOperation)), block(std::move(block)) {}
739436c6c9cSStella Laurenzo 
740436c6c9cSStella Laurenzo   llvm::Optional<PyOperationRef> refOperation;
741436c6c9cSStella Laurenzo   PyBlock block;
742436c6c9cSStella Laurenzo };
7432995d29bSAlex Zinenko /// Wrapper around the generic MlirType.
7442995d29bSAlex Zinenko /// The lifetime of a type is bound by the PyContext that created it.
7452995d29bSAlex Zinenko class PyType : public BaseContextObject {
7462995d29bSAlex Zinenko public:
PyType(PyMlirContextRef contextRef,MlirType type)7472995d29bSAlex Zinenko   PyType(PyMlirContextRef contextRef, MlirType type)
7482995d29bSAlex Zinenko       : BaseContextObject(std::move(contextRef)), type(type) {}
7492995d29bSAlex Zinenko   bool operator==(const PyType &other);
MlirType()7502995d29bSAlex Zinenko   operator MlirType() const { return type; }
get()7512995d29bSAlex Zinenko   MlirType get() const { return type; }
7522995d29bSAlex Zinenko 
7532995d29bSAlex Zinenko   /// Gets a capsule wrapping the void* within the MlirType.
7542995d29bSAlex Zinenko   pybind11::object getCapsule();
7552995d29bSAlex Zinenko 
7562995d29bSAlex Zinenko   /// Creates a PyType from the MlirType wrapped by a capsule.
7572995d29bSAlex Zinenko   /// Note that PyType instances are uniqued, so the returned object
7582995d29bSAlex Zinenko   /// may be a pre-existing object. Ownership of the underlying MlirType
7592995d29bSAlex Zinenko   /// is taken by calling this function.
7602995d29bSAlex Zinenko   static PyType createFromCapsule(pybind11::object capsule);
7612995d29bSAlex Zinenko 
7622995d29bSAlex Zinenko private:
7632995d29bSAlex Zinenko   MlirType type;
7642995d29bSAlex Zinenko };
7652995d29bSAlex Zinenko 
7662995d29bSAlex Zinenko /// CRTP base classes for Python types that subclass Type and should be
7672995d29bSAlex Zinenko /// castable from it (i.e. via something like IntegerType(t)).
7682995d29bSAlex Zinenko /// By default, type class hierarchies are one level deep (i.e. a
7692995d29bSAlex Zinenko /// concrete type class extends PyType); however, intermediate python-visible
7702995d29bSAlex Zinenko /// base classes can be modeled by specifying a BaseTy.
7712995d29bSAlex Zinenko template <typename DerivedTy, typename BaseTy = PyType>
7722995d29bSAlex Zinenko class PyConcreteType : public BaseTy {
7732995d29bSAlex Zinenko public:
7742995d29bSAlex Zinenko   // Derived classes must define statics for:
7752995d29bSAlex Zinenko   //   IsAFunctionTy isaFunction
7762995d29bSAlex Zinenko   //   const char *pyClassName
7772995d29bSAlex Zinenko   using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
7782995d29bSAlex Zinenko   using IsAFunctionTy = bool (*)(MlirType);
7792995d29bSAlex Zinenko 
7802995d29bSAlex Zinenko   PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef,MlirType t)7812995d29bSAlex Zinenko   PyConcreteType(PyMlirContextRef contextRef, MlirType t)
7822995d29bSAlex Zinenko       : BaseTy(std::move(contextRef), t) {}
PyConcreteType(PyType & orig)7832995d29bSAlex Zinenko   PyConcreteType(PyType &orig)
7842995d29bSAlex Zinenko       : PyConcreteType(orig.getContext(), castFrom(orig)) {}
7852995d29bSAlex Zinenko 
castFrom(PyType & orig)7862995d29bSAlex Zinenko   static MlirType castFrom(PyType &orig) {
7872995d29bSAlex Zinenko     if (!DerivedTy::isaFunction(orig)) {
7882995d29bSAlex Zinenko       auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
7892995d29bSAlex Zinenko       throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
7902995d29bSAlex Zinenko                                              DerivedTy::pyClassName +
7912995d29bSAlex Zinenko                                              " (from " + origRepr + ")");
7922995d29bSAlex Zinenko     }
7932995d29bSAlex Zinenko     return orig;
7942995d29bSAlex Zinenko   }
7952995d29bSAlex Zinenko 
bind(pybind11::module & m)7962995d29bSAlex Zinenko   static void bind(pybind11::module &m) {
7972995d29bSAlex Zinenko     auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
798a6e7d024SStella Laurenzo     cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(),
799a6e7d024SStella Laurenzo             pybind11::arg("cast_from_type"));
800a6e7d024SStella Laurenzo     cls.def_static(
801a6e7d024SStella Laurenzo         "isinstance",
802a6e7d024SStella Laurenzo         [](PyType &otherType) -> bool {
8032995d29bSAlex Zinenko           return DerivedTy::isaFunction(otherType);
804a6e7d024SStella Laurenzo         },
805a6e7d024SStella Laurenzo         pybind11::arg("other"));
8062995d29bSAlex Zinenko     DerivedTy::bindDerived(cls);
8072995d29bSAlex Zinenko   }
8082995d29bSAlex Zinenko 
8092995d29bSAlex Zinenko   /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)8102995d29bSAlex Zinenko   static void bindDerived(ClassTy &m) {}
8112995d29bSAlex Zinenko };
812436c6c9cSStella Laurenzo 
813436c6c9cSStella Laurenzo /// Wrapper around the generic MlirAttribute.
814436c6c9cSStella Laurenzo /// The lifetime of a type is bound by the PyContext that created it.
815436c6c9cSStella Laurenzo class PyAttribute : public BaseContextObject {
816436c6c9cSStella Laurenzo public:
PyAttribute(PyMlirContextRef contextRef,MlirAttribute attr)817436c6c9cSStella Laurenzo   PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
818436c6c9cSStella Laurenzo       : BaseContextObject(std::move(contextRef)), attr(attr) {}
819436c6c9cSStella Laurenzo   bool operator==(const PyAttribute &other);
MlirAttribute()820436c6c9cSStella Laurenzo   operator MlirAttribute() const { return attr; }
get()821436c6c9cSStella Laurenzo   MlirAttribute get() const { return attr; }
822436c6c9cSStella Laurenzo 
823436c6c9cSStella Laurenzo   /// Gets a capsule wrapping the void* within the MlirAttribute.
824436c6c9cSStella Laurenzo   pybind11::object getCapsule();
825436c6c9cSStella Laurenzo 
826436c6c9cSStella Laurenzo   /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
827436c6c9cSStella Laurenzo   /// Note that PyAttribute instances are uniqued, so the returned object
828436c6c9cSStella Laurenzo   /// may be a pre-existing object. Ownership of the underlying MlirAttribute
829436c6c9cSStella Laurenzo   /// is taken by calling this function.
830436c6c9cSStella Laurenzo   static PyAttribute createFromCapsule(pybind11::object capsule);
831436c6c9cSStella Laurenzo 
832436c6c9cSStella Laurenzo private:
833436c6c9cSStella Laurenzo   MlirAttribute attr;
834436c6c9cSStella Laurenzo };
835436c6c9cSStella Laurenzo 
836436c6c9cSStella Laurenzo /// Represents a Python MlirNamedAttr, carrying an optional owned name.
837436c6c9cSStella Laurenzo /// TODO: Refactor this and the C-API to be based on an Identifier owned
838436c6c9cSStella Laurenzo /// by the context so as to avoid ownership issues here.
839436c6c9cSStella Laurenzo class PyNamedAttribute {
840436c6c9cSStella Laurenzo public:
841436c6c9cSStella Laurenzo   /// Constructs a PyNamedAttr that retains an owned name. This should be
842436c6c9cSStella Laurenzo   /// used in any code that originates an MlirNamedAttribute from a python
843436c6c9cSStella Laurenzo   /// string.
844436c6c9cSStella Laurenzo   /// The lifetime of the PyNamedAttr must extend to the lifetime of the
845436c6c9cSStella Laurenzo   /// passed attribute.
846436c6c9cSStella Laurenzo   PyNamedAttribute(MlirAttribute attr, std::string ownedName);
847436c6c9cSStella Laurenzo 
848436c6c9cSStella Laurenzo   MlirNamedAttribute namedAttr;
849436c6c9cSStella Laurenzo 
850436c6c9cSStella Laurenzo private:
851436c6c9cSStella Laurenzo   // Since the MlirNamedAttr contains an internal pointer to the actual
852436c6c9cSStella Laurenzo   // memory of the owned string, it must be heap allocated to remain valid.
853436c6c9cSStella Laurenzo   // Otherwise, strings that fit within the small object optimization threshold
854436c6c9cSStella Laurenzo   // will have their memory address change as the containing object is moved,
855436c6c9cSStella Laurenzo   // resulting in an invalid aliased pointer.
856436c6c9cSStella Laurenzo   std::unique_ptr<std::string> ownedName;
857436c6c9cSStella Laurenzo };
858436c6c9cSStella Laurenzo 
8590b10fdedSAlex Zinenko /// CRTP base classes for Python attributes that subclass Attribute and should
8600b10fdedSAlex Zinenko /// be castable from it (i.e. via something like StringAttr(attr)).
8610b10fdedSAlex Zinenko /// By default, attribute class hierarchies are one level deep (i.e. a
8620b10fdedSAlex Zinenko /// concrete attribute class extends PyAttribute); however, intermediate
8630b10fdedSAlex Zinenko /// python-visible base classes can be modeled by specifying a BaseTy.
8640b10fdedSAlex Zinenko template <typename DerivedTy, typename BaseTy = PyAttribute>
8650b10fdedSAlex Zinenko class PyConcreteAttribute : public BaseTy {
8660b10fdedSAlex Zinenko public:
8670b10fdedSAlex Zinenko   // Derived classes must define statics for:
8680b10fdedSAlex Zinenko   //   IsAFunctionTy isaFunction
8690b10fdedSAlex Zinenko   //   const char *pyClassName
8700b10fdedSAlex Zinenko   using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
8710b10fdedSAlex Zinenko   using IsAFunctionTy = bool (*)(MlirAttribute);
8720b10fdedSAlex Zinenko 
8730b10fdedSAlex Zinenko   PyConcreteAttribute() = default;
PyConcreteAttribute(PyMlirContextRef contextRef,MlirAttribute attr)8740b10fdedSAlex Zinenko   PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
8750b10fdedSAlex Zinenko       : BaseTy(std::move(contextRef), attr) {}
PyConcreteAttribute(PyAttribute & orig)8760b10fdedSAlex Zinenko   PyConcreteAttribute(PyAttribute &orig)
8770b10fdedSAlex Zinenko       : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
8780b10fdedSAlex Zinenko 
castFrom(PyAttribute & orig)8790b10fdedSAlex Zinenko   static MlirAttribute castFrom(PyAttribute &orig) {
8800b10fdedSAlex Zinenko     if (!DerivedTy::isaFunction(orig)) {
8810b10fdedSAlex Zinenko       auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
8820b10fdedSAlex Zinenko       throw SetPyError(PyExc_ValueError,
8830b10fdedSAlex Zinenko                        llvm::Twine("Cannot cast attribute to ") +
8840b10fdedSAlex Zinenko                            DerivedTy::pyClassName + " (from " + origRepr + ")");
8850b10fdedSAlex Zinenko     }
8860b10fdedSAlex Zinenko     return orig;
8870b10fdedSAlex Zinenko   }
8880b10fdedSAlex Zinenko 
bind(pybind11::module & m)8890b10fdedSAlex Zinenko   static void bind(pybind11::module &m) {
8908dca953dSSean Silva     auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(),
8918dca953dSSean Silva                        pybind11::module_local());
892a6e7d024SStella Laurenzo     cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(),
893a6e7d024SStella Laurenzo             pybind11::arg("cast_from_attr"));
894a6e7d024SStella Laurenzo     cls.def_static(
895a6e7d024SStella Laurenzo         "isinstance",
896a6e7d024SStella Laurenzo         [](PyAttribute &otherAttr) -> bool {
89778f2dae0SAlex Zinenko           return DerivedTy::isaFunction(otherAttr);
898a6e7d024SStella Laurenzo         },
899a6e7d024SStella Laurenzo         pybind11::arg("other"));
9002995d29bSAlex Zinenko     cls.def_property_readonly("type", [](PyAttribute &attr) {
9012995d29bSAlex Zinenko       return PyType(attr.getContext(), mlirAttributeGetType(attr));
90232e2fec7SJohn Demme     });
90332e2fec7SJohn Demme     DerivedTy::bindDerived(cls);
90432e2fec7SJohn Demme   }
90532e2fec7SJohn Demme 
90632e2fec7SJohn Demme   /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)90732e2fec7SJohn Demme   static void bindDerived(ClassTy &m) {}
90832e2fec7SJohn Demme };
90932e2fec7SJohn Demme 
910436c6c9cSStella Laurenzo /// Wrapper around the generic MlirValue.
911436c6c9cSStella Laurenzo /// Values are managed completely by the operation that resulted in their
912436c6c9cSStella Laurenzo /// definition. For op result value, this is the operation that defines the
913436c6c9cSStella Laurenzo /// value. For block argument values, this is the operation that contains the
914436c6c9cSStella Laurenzo /// block to which the value is an argument (blocks cannot be detached in Python
915436c6c9cSStella Laurenzo /// bindings so such operation always exists).
916436c6c9cSStella Laurenzo class PyValue {
917436c6c9cSStella Laurenzo public:
PyValue(PyOperationRef parentOperation,MlirValue value)918436c6c9cSStella Laurenzo   PyValue(PyOperationRef parentOperation, MlirValue value)
919e8d07395SMehdi Amini       : parentOperation(std::move(parentOperation)), value(value) {}
MlirValue()92078f2dae0SAlex Zinenko   operator MlirValue() const { return value; }
921436c6c9cSStella Laurenzo 
get()922436c6c9cSStella Laurenzo   MlirValue get() { return value; }
getParentOperation()923436c6c9cSStella Laurenzo   PyOperationRef &getParentOperation() { return parentOperation; }
924436c6c9cSStella Laurenzo 
checkValid()925436c6c9cSStella Laurenzo   void checkValid() { return parentOperation->checkValid(); }
926436c6c9cSStella Laurenzo 
9273f3d1c90SMike Urbach   /// Gets a capsule wrapping the void* within the MlirValue.
9283f3d1c90SMike Urbach   pybind11::object getCapsule();
9293f3d1c90SMike Urbach 
9303f3d1c90SMike Urbach   /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
9313f3d1c90SMike Urbach   /// the underlying MlirValue is still tied to the owning operation.
9323f3d1c90SMike Urbach   static PyValue createFromCapsule(pybind11::object capsule);
9333f3d1c90SMike Urbach 
934436c6c9cSStella Laurenzo private:
935436c6c9cSStella Laurenzo   PyOperationRef parentOperation;
936436c6c9cSStella Laurenzo   MlirValue value;
937436c6c9cSStella Laurenzo };
938436c6c9cSStella Laurenzo 
939436c6c9cSStella Laurenzo /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
940436c6c9cSStella Laurenzo class PyAffineExpr : public BaseContextObject {
941436c6c9cSStella Laurenzo public:
PyAffineExpr(PyMlirContextRef contextRef,MlirAffineExpr affineExpr)942436c6c9cSStella Laurenzo   PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
943436c6c9cSStella Laurenzo       : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}
944436c6c9cSStella Laurenzo   bool operator==(const PyAffineExpr &other);
MlirAffineExpr()945436c6c9cSStella Laurenzo   operator MlirAffineExpr() const { return affineExpr; }
get()946436c6c9cSStella Laurenzo   MlirAffineExpr get() const { return affineExpr; }
947436c6c9cSStella Laurenzo 
948436c6c9cSStella Laurenzo   /// Gets a capsule wrapping the void* within the MlirAffineExpr.
949436c6c9cSStella Laurenzo   pybind11::object getCapsule();
950436c6c9cSStella Laurenzo 
951436c6c9cSStella Laurenzo   /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
952436c6c9cSStella Laurenzo   /// Note that PyAffineExpr instances are uniqued, so the returned object
953436c6c9cSStella Laurenzo   /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
954436c6c9cSStella Laurenzo   /// is taken by calling this function.
955436c6c9cSStella Laurenzo   static PyAffineExpr createFromCapsule(pybind11::object capsule);
956436c6c9cSStella Laurenzo 
957436c6c9cSStella Laurenzo   PyAffineExpr add(const PyAffineExpr &other) const;
958436c6c9cSStella Laurenzo   PyAffineExpr mul(const PyAffineExpr &other) const;
959436c6c9cSStella Laurenzo   PyAffineExpr floorDiv(const PyAffineExpr &other) const;
960436c6c9cSStella Laurenzo   PyAffineExpr ceilDiv(const PyAffineExpr &other) const;
961436c6c9cSStella Laurenzo   PyAffineExpr mod(const PyAffineExpr &other) const;
962436c6c9cSStella Laurenzo 
963436c6c9cSStella Laurenzo private:
964436c6c9cSStella Laurenzo   MlirAffineExpr affineExpr;
965436c6c9cSStella Laurenzo };
966436c6c9cSStella Laurenzo 
967436c6c9cSStella Laurenzo class PyAffineMap : public BaseContextObject {
968436c6c9cSStella Laurenzo public:
PyAffineMap(PyMlirContextRef contextRef,MlirAffineMap affineMap)969436c6c9cSStella Laurenzo   PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)
970436c6c9cSStella Laurenzo       : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {}
971436c6c9cSStella Laurenzo   bool operator==(const PyAffineMap &other);
MlirAffineMap()972436c6c9cSStella Laurenzo   operator MlirAffineMap() const { return affineMap; }
get()973436c6c9cSStella Laurenzo   MlirAffineMap get() const { return affineMap; }
974436c6c9cSStella Laurenzo 
975436c6c9cSStella Laurenzo   /// Gets a capsule wrapping the void* within the MlirAffineMap.
976436c6c9cSStella Laurenzo   pybind11::object getCapsule();
977436c6c9cSStella Laurenzo 
978436c6c9cSStella Laurenzo   /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule.
979436c6c9cSStella Laurenzo   /// Note that PyAffineMap instances are uniqued, so the returned object
980436c6c9cSStella Laurenzo   /// may be a pre-existing object. Ownership of the underlying MlirAffineMap
981436c6c9cSStella Laurenzo   /// is taken by calling this function.
982436c6c9cSStella Laurenzo   static PyAffineMap createFromCapsule(pybind11::object capsule);
983436c6c9cSStella Laurenzo 
984436c6c9cSStella Laurenzo private:
985436c6c9cSStella Laurenzo   MlirAffineMap affineMap;
986436c6c9cSStella Laurenzo };
987436c6c9cSStella Laurenzo 
988436c6c9cSStella Laurenzo class PyIntegerSet : public BaseContextObject {
989436c6c9cSStella Laurenzo public:
PyIntegerSet(PyMlirContextRef contextRef,MlirIntegerSet integerSet)990436c6c9cSStella Laurenzo   PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
991436c6c9cSStella Laurenzo       : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
992436c6c9cSStella Laurenzo   bool operator==(const PyIntegerSet &other);
MlirIntegerSet()993436c6c9cSStella Laurenzo   operator MlirIntegerSet() const { return integerSet; }
get()994436c6c9cSStella Laurenzo   MlirIntegerSet get() const { return integerSet; }
995436c6c9cSStella Laurenzo 
996436c6c9cSStella Laurenzo   /// Gets a capsule wrapping the void* within the MlirIntegerSet.
997436c6c9cSStella Laurenzo   pybind11::object getCapsule();
998436c6c9cSStella Laurenzo 
999436c6c9cSStella Laurenzo   /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
1000436c6c9cSStella Laurenzo   /// Note that PyIntegerSet instances may be uniqued, so the returned object
1001436c6c9cSStella Laurenzo   /// may be a pre-existing object. Integer sets are owned by the context.
1002436c6c9cSStella Laurenzo   static PyIntegerSet createFromCapsule(pybind11::object capsule);
1003436c6c9cSStella Laurenzo 
1004436c6c9cSStella Laurenzo private:
1005436c6c9cSStella Laurenzo   MlirIntegerSet integerSet;
1006436c6c9cSStella Laurenzo };
1007436c6c9cSStella Laurenzo 
100830d61893SAlex Zinenko /// Bindings for MLIR symbol tables.
100930d61893SAlex Zinenko class PySymbolTable {
101030d61893SAlex Zinenko public:
101130d61893SAlex Zinenko   /// Constructs a symbol table for the given operation.
101230d61893SAlex Zinenko   explicit PySymbolTable(PyOperationBase &operation);
101330d61893SAlex Zinenko 
101430d61893SAlex Zinenko   /// Destroys the symbol table.
~PySymbolTable()101530d61893SAlex Zinenko   ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); }
101630d61893SAlex Zinenko 
101730d61893SAlex Zinenko   /// Returns the symbol (opview) with the given name, throws if there is no
101830d61893SAlex Zinenko   /// such symbol in the table.
101930d61893SAlex Zinenko   pybind11::object dunderGetItem(const std::string &name);
102030d61893SAlex Zinenko 
102130d61893SAlex Zinenko   /// Removes the given operation from the symbol table and erases it.
102230d61893SAlex Zinenko   void erase(PyOperationBase &symbol);
102330d61893SAlex Zinenko 
102430d61893SAlex Zinenko   /// Removes the operation with the given name from the symbol table and erases
102530d61893SAlex Zinenko   /// it, throws if there is no such symbol in the table.
102630d61893SAlex Zinenko   void dunderDel(const std::string &name);
102730d61893SAlex Zinenko 
102830d61893SAlex Zinenko   /// Inserts the given operation into the symbol table. The operation must have
102930d61893SAlex Zinenko   /// the symbol trait.
103030d61893SAlex Zinenko   PyAttribute insert(PyOperationBase &symbol);
103130d61893SAlex Zinenko 
1032bdc31837SStella Laurenzo   /// Gets and sets the name of a symbol op.
1033bdc31837SStella Laurenzo   static PyAttribute getSymbolName(PyOperationBase &symbol);
1034bdc31837SStella Laurenzo   static void setSymbolName(PyOperationBase &symbol, const std::string &name);
1035bdc31837SStella Laurenzo 
1036bdc31837SStella Laurenzo   /// Gets and sets the visibility of a symbol op.
1037bdc31837SStella Laurenzo   static PyAttribute getVisibility(PyOperationBase &symbol);
1038bdc31837SStella Laurenzo   static void setVisibility(PyOperationBase &symbol,
1039bdc31837SStella Laurenzo                             const std::string &visibility);
1040bdc31837SStella Laurenzo 
1041bdc31837SStella Laurenzo   /// Replaces all symbol uses within an operation. See the API
1042bdc31837SStella Laurenzo   /// mlirSymbolTableReplaceAllSymbolUses for all caveats.
1043bdc31837SStella Laurenzo   static void replaceAllSymbolUses(const std::string &oldSymbol,
1044bdc31837SStella Laurenzo                                    const std::string &newSymbol,
1045bdc31837SStella Laurenzo                                    PyOperationBase &from);
1046bdc31837SStella Laurenzo 
1047bdc31837SStella Laurenzo   /// Walks all symbol tables under and including 'from'.
1048bdc31837SStella Laurenzo   static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible,
1049bdc31837SStella Laurenzo                                pybind11::object callback);
1050bdc31837SStella Laurenzo 
105130d61893SAlex Zinenko   /// Casts the bindings class into the C API structure.
MlirSymbolTable()105230d61893SAlex Zinenko   operator MlirSymbolTable() { return symbolTable; }
105330d61893SAlex Zinenko 
105430d61893SAlex Zinenko private:
105530d61893SAlex Zinenko   PyOperationRef operation;
105630d61893SAlex Zinenko   MlirSymbolTable symbolTable;
105730d61893SAlex Zinenko };
105830d61893SAlex Zinenko 
1059436c6c9cSStella Laurenzo void populateIRAffine(pybind11::module &m);
1060436c6c9cSStella Laurenzo void populateIRAttributes(pybind11::module &m);
1061436c6c9cSStella Laurenzo void populateIRCore(pybind11::module &m);
106214c92070SAlex Zinenko void populateIRInterfaces(pybind11::module &m);
1063436c6c9cSStella Laurenzo void populateIRTypes(pybind11::module &m);
1064436c6c9cSStella Laurenzo 
1065436c6c9cSStella Laurenzo } // namespace python
1066436c6c9cSStella Laurenzo } // namespace mlir
1067436c6c9cSStella Laurenzo 
1068436c6c9cSStella Laurenzo namespace pybind11 {
1069436c6c9cSStella Laurenzo namespace detail {
1070436c6c9cSStella Laurenzo 
1071436c6c9cSStella Laurenzo template <>
1072436c6c9cSStella Laurenzo struct type_caster<mlir::python::DefaultingPyMlirContext>
1073436c6c9cSStella Laurenzo     : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {};
1074436c6c9cSStella Laurenzo template <>
1075436c6c9cSStella Laurenzo struct type_caster<mlir::python::DefaultingPyLocation>
1076436c6c9cSStella Laurenzo     : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};
1077436c6c9cSStella Laurenzo 
1078436c6c9cSStella Laurenzo } // namespace detail
1079436c6c9cSStella Laurenzo } // namespace pybind11
1080436c6c9cSStella Laurenzo 
1081436c6c9cSStella Laurenzo #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
1082