1436c6c9cSStella Laurenzo //===- IRModules.h - IR Submodules of pybind module -----------------------===// 2436c6c9cSStella Laurenzo // 3436c6c9cSStella Laurenzo // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4436c6c9cSStella Laurenzo // See https://llvm.org/LICENSE.txt for license information. 5436c6c9cSStella Laurenzo // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6436c6c9cSStella Laurenzo // 7436c6c9cSStella Laurenzo //===----------------------------------------------------------------------===// 8436c6c9cSStella Laurenzo 9436c6c9cSStella Laurenzo #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H 10436c6c9cSStella Laurenzo #define MLIR_BINDINGS_PYTHON_IRMODULES_H 11436c6c9cSStella Laurenzo 12e8d07395SMehdi Amini #include <utility> 13436c6c9cSStella Laurenzo #include <vector> 14436c6c9cSStella Laurenzo 15436c6c9cSStella Laurenzo #include "PybindUtils.h" 16436c6c9cSStella Laurenzo 17436c6c9cSStella Laurenzo #include "mlir-c/AffineExpr.h" 18436c6c9cSStella Laurenzo #include "mlir-c/AffineMap.h" 197ee25bc5SStella Laurenzo #include "mlir-c/Diagnostics.h" 20436c6c9cSStella Laurenzo #include "mlir-c/IR.h" 21436c6c9cSStella Laurenzo #include "mlir-c/IntegerSet.h" 22436c6c9cSStella Laurenzo #include "llvm/ADT/DenseMap.h" 231689dadeSJohn Demme #include "llvm/ADT/Optional.h" 24436c6c9cSStella Laurenzo 25436c6c9cSStella Laurenzo namespace mlir { 26436c6c9cSStella Laurenzo namespace python { 27436c6c9cSStella Laurenzo 28436c6c9cSStella Laurenzo class PyBlock; 297ee25bc5SStella Laurenzo class PyDiagnostic; 307ee25bc5SStella Laurenzo class PyDiagnosticHandler; 31436c6c9cSStella Laurenzo class PyInsertionPoint; 32436c6c9cSStella Laurenzo class PyLocation; 33436c6c9cSStella Laurenzo class DefaultingPyLocation; 34436c6c9cSStella Laurenzo class PyMlirContext; 35436c6c9cSStella Laurenzo class DefaultingPyMlirContext; 36436c6c9cSStella Laurenzo class PyModule; 37436c6c9cSStella Laurenzo class PyOperation; 38436c6c9cSStella Laurenzo class PyType; 3930d61893SAlex Zinenko class PySymbolTable; 40436c6c9cSStella Laurenzo class PyValue; 41436c6c9cSStella Laurenzo 42436c6c9cSStella Laurenzo /// Template for a reference to a concrete type which captures a python 43436c6c9cSStella Laurenzo /// reference to its underlying python object. 44436c6c9cSStella Laurenzo template <typename T> 45436c6c9cSStella Laurenzo class PyObjectRef { 46436c6c9cSStella Laurenzo public: PyObjectRef(T * referrent,pybind11::object object)47436c6c9cSStella Laurenzo PyObjectRef(T *referrent, pybind11::object object) 48436c6c9cSStella Laurenzo : referrent(referrent), object(std::move(object)) { 49436c6c9cSStella Laurenzo assert(this->referrent && 50436c6c9cSStella Laurenzo "cannot construct PyObjectRef with null referrent"); 51436c6c9cSStella Laurenzo assert(this->object && "cannot construct PyObjectRef with null object"); 52436c6c9cSStella Laurenzo } PyObjectRef(PyObjectRef && other)53436c6c9cSStella Laurenzo PyObjectRef(PyObjectRef &&other) 54436c6c9cSStella Laurenzo : referrent(other.referrent), object(std::move(other.object)) { 55436c6c9cSStella Laurenzo other.referrent = nullptr; 56436c6c9cSStella Laurenzo assert(!other.object); 57436c6c9cSStella Laurenzo } PyObjectRef(const PyObjectRef & other)58436c6c9cSStella Laurenzo PyObjectRef(const PyObjectRef &other) 59436c6c9cSStella Laurenzo : referrent(other.referrent), object(other.object /* copies */) {} 609940dcfaSMehdi Amini ~PyObjectRef() = default; 61436c6c9cSStella Laurenzo getRefCount()62436c6c9cSStella Laurenzo int getRefCount() { 63436c6c9cSStella Laurenzo if (!object) 64436c6c9cSStella Laurenzo return 0; 65436c6c9cSStella Laurenzo return object.ref_count(); 66436c6c9cSStella Laurenzo } 67436c6c9cSStella Laurenzo 68436c6c9cSStella Laurenzo /// Releases the object held by this instance, returning it. 69436c6c9cSStella Laurenzo /// This is the proper thing to return from a function that wants to return 70436c6c9cSStella Laurenzo /// the reference. Note that this does not work from initializers. releaseObject()71436c6c9cSStella Laurenzo pybind11::object releaseObject() { 72436c6c9cSStella Laurenzo assert(referrent && object); 73436c6c9cSStella Laurenzo referrent = nullptr; 74436c6c9cSStella Laurenzo auto stolen = std::move(object); 75436c6c9cSStella Laurenzo return stolen; 76436c6c9cSStella Laurenzo } 77436c6c9cSStella Laurenzo get()78436c6c9cSStella Laurenzo T *get() { return referrent; } 79436c6c9cSStella Laurenzo T *operator->() { 80436c6c9cSStella Laurenzo assert(referrent && object); 81436c6c9cSStella Laurenzo return referrent; 82436c6c9cSStella Laurenzo } getObject()83436c6c9cSStella Laurenzo pybind11::object getObject() { 84436c6c9cSStella Laurenzo assert(referrent && object); 85436c6c9cSStella Laurenzo return object; 86436c6c9cSStella Laurenzo } 87436c6c9cSStella Laurenzo operator bool() const { return referrent && object; } 88436c6c9cSStella Laurenzo 89436c6c9cSStella Laurenzo private: 90436c6c9cSStella Laurenzo T *referrent; 91436c6c9cSStella Laurenzo pybind11::object object; 92436c6c9cSStella Laurenzo }; 93436c6c9cSStella Laurenzo 94436c6c9cSStella Laurenzo /// Tracks an entry in the thread context stack. New entries are pushed onto 95436c6c9cSStella Laurenzo /// here for each with block that activates a new InsertionPoint, Context or 96436c6c9cSStella Laurenzo /// Location. 97436c6c9cSStella Laurenzo /// 98436c6c9cSStella Laurenzo /// Pushing either a Location or InsertionPoint also pushes its associated 99436c6c9cSStella Laurenzo /// Context. Pushing a Context will not modify the Location or InsertionPoint 100436c6c9cSStella Laurenzo /// unless if they are from a different context, in which case, they are 101436c6c9cSStella Laurenzo /// cleared. 102436c6c9cSStella Laurenzo class PyThreadContextEntry { 103436c6c9cSStella Laurenzo public: 104436c6c9cSStella Laurenzo enum class FrameKind { 105436c6c9cSStella Laurenzo Context, 106436c6c9cSStella Laurenzo InsertionPoint, 107436c6c9cSStella Laurenzo Location, 108436c6c9cSStella Laurenzo }; 109436c6c9cSStella Laurenzo PyThreadContextEntry(FrameKind frameKind,pybind11::object context,pybind11::object insertionPoint,pybind11::object location)110436c6c9cSStella Laurenzo PyThreadContextEntry(FrameKind frameKind, pybind11::object context, 111436c6c9cSStella Laurenzo pybind11::object insertionPoint, 112436c6c9cSStella Laurenzo pybind11::object location) 113436c6c9cSStella Laurenzo : context(std::move(context)), insertionPoint(std::move(insertionPoint)), 114436c6c9cSStella Laurenzo location(std::move(location)), frameKind(frameKind) {} 115436c6c9cSStella Laurenzo 116436c6c9cSStella Laurenzo /// Gets the top of stack context and return nullptr if not defined. 117436c6c9cSStella Laurenzo static PyMlirContext *getDefaultContext(); 118436c6c9cSStella Laurenzo 119436c6c9cSStella Laurenzo /// Gets the top of stack insertion point and return nullptr if not defined. 120436c6c9cSStella Laurenzo static PyInsertionPoint *getDefaultInsertionPoint(); 121436c6c9cSStella Laurenzo 122436c6c9cSStella Laurenzo /// Gets the top of stack location and returns nullptr if not defined. 123436c6c9cSStella Laurenzo static PyLocation *getDefaultLocation(); 124436c6c9cSStella Laurenzo 125436c6c9cSStella Laurenzo PyMlirContext *getContext(); 126436c6c9cSStella Laurenzo PyInsertionPoint *getInsertionPoint(); 127436c6c9cSStella Laurenzo PyLocation *getLocation(); getFrameKind()128436c6c9cSStella Laurenzo FrameKind getFrameKind() { return frameKind; } 129436c6c9cSStella Laurenzo 130436c6c9cSStella Laurenzo /// Stack management. 131436c6c9cSStella Laurenzo static PyThreadContextEntry *getTopOfStack(); 132436c6c9cSStella Laurenzo static pybind11::object pushContext(PyMlirContext &context); 133436c6c9cSStella Laurenzo static void popContext(PyMlirContext &context); 134436c6c9cSStella Laurenzo static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); 135436c6c9cSStella Laurenzo static void popInsertionPoint(PyInsertionPoint &insertionPoint); 136436c6c9cSStella Laurenzo static pybind11::object pushLocation(PyLocation &location); 137436c6c9cSStella Laurenzo static void popLocation(PyLocation &location); 138436c6c9cSStella Laurenzo 139436c6c9cSStella Laurenzo /// Gets the thread local stack. 140436c6c9cSStella Laurenzo static std::vector<PyThreadContextEntry> &getStack(); 141436c6c9cSStella Laurenzo 142436c6c9cSStella Laurenzo private: 143436c6c9cSStella Laurenzo static void push(FrameKind frameKind, pybind11::object context, 144436c6c9cSStella Laurenzo pybind11::object insertionPoint, pybind11::object location); 145436c6c9cSStella Laurenzo 146436c6c9cSStella Laurenzo /// An object reference to the PyContext. 147436c6c9cSStella Laurenzo pybind11::object context; 148436c6c9cSStella Laurenzo /// An object reference to the current insertion point. 149436c6c9cSStella Laurenzo pybind11::object insertionPoint; 150436c6c9cSStella Laurenzo /// An object reference to the current location. 151436c6c9cSStella Laurenzo pybind11::object location; 152436c6c9cSStella Laurenzo // The kind of push that was performed. 153436c6c9cSStella Laurenzo FrameKind frameKind; 154436c6c9cSStella Laurenzo }; 155436c6c9cSStella Laurenzo 156436c6c9cSStella Laurenzo /// Wrapper around MlirContext. 157436c6c9cSStella Laurenzo using PyMlirContextRef = PyObjectRef<PyMlirContext>; 158436c6c9cSStella Laurenzo class PyMlirContext { 159436c6c9cSStella Laurenzo public: 160436c6c9cSStella Laurenzo PyMlirContext() = delete; 161436c6c9cSStella Laurenzo PyMlirContext(const PyMlirContext &) = delete; 162436c6c9cSStella Laurenzo PyMlirContext(PyMlirContext &&) = delete; 163436c6c9cSStella Laurenzo 164436c6c9cSStella Laurenzo /// For the case of a python __init__ (py::init) method, pybind11 is quite 165436c6c9cSStella Laurenzo /// strict about needing to return a pointer that is not yet associated to 166436c6c9cSStella Laurenzo /// an py::object. Since the forContext() method acts like a pool, possibly 167436c6c9cSStella Laurenzo /// returning a recycled context, it does not satisfy this need. The usual 168436c6c9cSStella Laurenzo /// way in python to accomplish such a thing is to override __new__, but 169436c6c9cSStella Laurenzo /// that is also not supported by pybind11. Instead, we use this entry 170436c6c9cSStella Laurenzo /// point which always constructs a fresh context (which cannot alias an 171436c6c9cSStella Laurenzo /// existing one because it is fresh). 172436c6c9cSStella Laurenzo static PyMlirContext *createNewContextForInit(); 173436c6c9cSStella Laurenzo 174436c6c9cSStella Laurenzo /// Returns a context reference for the singleton PyMlirContext wrapper for 175436c6c9cSStella Laurenzo /// the given context. 176436c6c9cSStella Laurenzo static PyMlirContextRef forContext(MlirContext context); 177436c6c9cSStella Laurenzo ~PyMlirContext(); 178436c6c9cSStella Laurenzo 179436c6c9cSStella Laurenzo /// Accesses the underlying MlirContext. get()180436c6c9cSStella Laurenzo MlirContext get() { return context; } 181436c6c9cSStella Laurenzo 182436c6c9cSStella Laurenzo /// Gets a strong reference to this context, which will ensure it is kept 183436c6c9cSStella Laurenzo /// alive for the life of the reference. getRef()184436c6c9cSStella Laurenzo PyMlirContextRef getRef() { 185436c6c9cSStella Laurenzo return PyMlirContextRef(this, pybind11::cast(this)); 186436c6c9cSStella Laurenzo } 187436c6c9cSStella Laurenzo 188436c6c9cSStella Laurenzo /// Gets a capsule wrapping the void* within the MlirContext. 189436c6c9cSStella Laurenzo pybind11::object getCapsule(); 190436c6c9cSStella Laurenzo 191436c6c9cSStella Laurenzo /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. 192436c6c9cSStella Laurenzo /// Note that PyMlirContext instances are uniqued, so the returned object 193436c6c9cSStella Laurenzo /// may be a pre-existing object. Ownership of the underlying MlirContext 194436c6c9cSStella Laurenzo /// is taken by calling this function. 195436c6c9cSStella Laurenzo static pybind11::object createFromCapsule(pybind11::object capsule); 196436c6c9cSStella Laurenzo 197436c6c9cSStella Laurenzo /// Gets the count of live context objects. Used for testing. 198436c6c9cSStella Laurenzo static size_t getLiveCount(); 199436c6c9cSStella Laurenzo 200436c6c9cSStella Laurenzo /// Gets the count of live operations associated with this context. 201436c6c9cSStella Laurenzo /// Used for testing. 202436c6c9cSStella Laurenzo size_t getLiveOperationCount(); 203436c6c9cSStella Laurenzo 2046b0bed7eSJohn Demme /// Clears the live operations map, returning the number of entries which were 2056b0bed7eSJohn Demme /// invalidated. To be used as a safety mechanism so that API end-users can't 2066b0bed7eSJohn Demme /// corrupt by holding references they shouldn't have accessed in the first 2076b0bed7eSJohn Demme /// place. 2086b0bed7eSJohn Demme size_t clearLiveOperations(); 2096b0bed7eSJohn Demme 210436c6c9cSStella Laurenzo /// Gets the count of live modules associated with this context. 211436c6c9cSStella Laurenzo /// Used for testing. 212436c6c9cSStella Laurenzo size_t getLiveModuleCount(); 213436c6c9cSStella Laurenzo 214436c6c9cSStella Laurenzo /// Enter and exit the context manager. 215436c6c9cSStella Laurenzo pybind11::object contextEnter(); 2161fc096afSMehdi Amini void contextExit(const pybind11::object &excType, 2171fc096afSMehdi Amini const pybind11::object &excVal, 2181fc096afSMehdi Amini const pybind11::object &excTb); 219436c6c9cSStella Laurenzo 2207ee25bc5SStella Laurenzo /// Attaches a Python callback as a diagnostic handler, returning a 2217ee25bc5SStella Laurenzo /// registration object (internally a PyDiagnosticHandler). 2227ee25bc5SStella Laurenzo pybind11::object attachDiagnosticHandler(pybind11::object callback); 2237ee25bc5SStella Laurenzo 224436c6c9cSStella Laurenzo private: 225436c6c9cSStella Laurenzo PyMlirContext(MlirContext context); 226436c6c9cSStella Laurenzo // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, 227436c6c9cSStella Laurenzo // preserving the relationship that an MlirContext maps to a single 228436c6c9cSStella Laurenzo // PyMlirContext wrapper. This could be replaced in the future with an 229436c6c9cSStella Laurenzo // extension mechanism on the MlirContext for stashing user pointers. 230436c6c9cSStella Laurenzo // Note that this holds a handle, which does not imply ownership. 231436c6c9cSStella Laurenzo // Mappings will be removed when the context is destructed. 232436c6c9cSStella Laurenzo using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>; 233436c6c9cSStella Laurenzo static LiveContextMap &getLiveContexts(); 234436c6c9cSStella Laurenzo 235436c6c9cSStella Laurenzo // Interns all live modules associated with this context. Modules tracked 236436c6c9cSStella Laurenzo // in this map are valid. When a module is invalidated, it is removed 237436c6c9cSStella Laurenzo // from this map, and while it still exists as an instance, any 238436c6c9cSStella Laurenzo // attempt to access it will raise an error. 239436c6c9cSStella Laurenzo using LiveModuleMap = 240436c6c9cSStella Laurenzo llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>; 241436c6c9cSStella Laurenzo LiveModuleMap liveModules; 242436c6c9cSStella Laurenzo 243436c6c9cSStella Laurenzo // Interns all live operations associated with this context. Operations 244436c6c9cSStella Laurenzo // tracked in this map are valid. When an operation is invalidated, it is 245436c6c9cSStella Laurenzo // removed from this map, and while it still exists as an instance, any 246436c6c9cSStella Laurenzo // attempt to access it will raise an error. 247436c6c9cSStella Laurenzo using LiveOperationMap = 248436c6c9cSStella Laurenzo llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>; 249436c6c9cSStella Laurenzo LiveOperationMap liveOperations; 250436c6c9cSStella Laurenzo 251436c6c9cSStella Laurenzo MlirContext context; 252436c6c9cSStella Laurenzo friend class PyModule; 253436c6c9cSStella Laurenzo friend class PyOperation; 254436c6c9cSStella Laurenzo }; 255436c6c9cSStella Laurenzo 256436c6c9cSStella Laurenzo /// Used in function arguments when None should resolve to the current context 257436c6c9cSStella Laurenzo /// manager set instance. 258436c6c9cSStella Laurenzo class DefaultingPyMlirContext 259436c6c9cSStella Laurenzo : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { 260436c6c9cSStella Laurenzo public: 261436c6c9cSStella Laurenzo using Defaulting::Defaulting; 262a6e7d024SStella Laurenzo static constexpr const char kTypeDescription[] = "mlir.ir.Context"; 263436c6c9cSStella Laurenzo static PyMlirContext &resolve(); 264436c6c9cSStella Laurenzo }; 265436c6c9cSStella Laurenzo 266436c6c9cSStella Laurenzo /// Base class for all objects that directly or indirectly depend on an 267436c6c9cSStella Laurenzo /// MlirContext. The lifetime of the context will extend at least to the 268436c6c9cSStella Laurenzo /// lifetime of these instances. 269436c6c9cSStella Laurenzo /// Immutable objects that depend on a context extend this directly. 270436c6c9cSStella Laurenzo class BaseContextObject { 271436c6c9cSStella Laurenzo public: BaseContextObject(PyMlirContextRef ref)272436c6c9cSStella Laurenzo BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { 273436c6c9cSStella Laurenzo assert(this->contextRef && 274436c6c9cSStella Laurenzo "context object constructed with null context ref"); 275436c6c9cSStella Laurenzo } 276436c6c9cSStella Laurenzo 277436c6c9cSStella Laurenzo /// Accesses the context reference. getContext()278436c6c9cSStella Laurenzo PyMlirContextRef &getContext() { return contextRef; } 279436c6c9cSStella Laurenzo 280436c6c9cSStella Laurenzo private: 281436c6c9cSStella Laurenzo PyMlirContextRef contextRef; 282436c6c9cSStella Laurenzo }; 283436c6c9cSStella Laurenzo 2847ee25bc5SStella Laurenzo /// Python class mirroring the C MlirDiagnostic struct. Note that these structs 2857ee25bc5SStella Laurenzo /// are only valid for the duration of a diagnostic callback and attempting 2867ee25bc5SStella Laurenzo /// to access them outside of that will raise an exception. This applies to 2877ee25bc5SStella Laurenzo /// nested diagnostics (in the notes) as well. 2887ee25bc5SStella Laurenzo class PyDiagnostic { 2897ee25bc5SStella Laurenzo public: PyDiagnostic(MlirDiagnostic diagnostic)2907ee25bc5SStella Laurenzo PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {} 2917ee25bc5SStella Laurenzo void invalidate(); isValid()2927ee25bc5SStella Laurenzo bool isValid() { return valid; } 2937ee25bc5SStella Laurenzo MlirDiagnosticSeverity getSeverity(); 2947ee25bc5SStella Laurenzo PyLocation getLocation(); 2957ee25bc5SStella Laurenzo pybind11::str getMessage(); 2967ee25bc5SStella Laurenzo pybind11::tuple getNotes(); 2977ee25bc5SStella Laurenzo 2987ee25bc5SStella Laurenzo private: 2997ee25bc5SStella Laurenzo MlirDiagnostic diagnostic; 3007ee25bc5SStella Laurenzo 3017ee25bc5SStella Laurenzo void checkValid(); 3027ee25bc5SStella Laurenzo /// If notes have been materialized from the diagnostic, then this will 3037ee25bc5SStella Laurenzo /// be populated with the corresponding objects (all castable to 3047ee25bc5SStella Laurenzo /// PyDiagnostic). 3057ee25bc5SStella Laurenzo llvm::Optional<pybind11::tuple> materializedNotes; 3067ee25bc5SStella Laurenzo bool valid = true; 3077ee25bc5SStella Laurenzo }; 3087ee25bc5SStella Laurenzo 3097ee25bc5SStella Laurenzo /// Represents a diagnostic handler attached to the context. The handler's 3107ee25bc5SStella Laurenzo /// callback will be invoked with PyDiagnostic instances until the detach() 3117ee25bc5SStella Laurenzo /// method is called or the context is destroyed. A diagnostic handler can be 3127ee25bc5SStella Laurenzo /// the subject of a `with` block, which will detach it when the block exits. 3137ee25bc5SStella Laurenzo /// 3147ee25bc5SStella Laurenzo /// Since diagnostic handlers can call back into Python code which can do 3157ee25bc5SStella Laurenzo /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions, 3167ee25bc5SStella Laurenzo /// etc), this is generally not deemed to be a great user-level API. Users 3177ee25bc5SStella Laurenzo /// should generally use some form of DiagnosticCollector. If the handler raises 3187ee25bc5SStella Laurenzo /// any exceptions, they will just be emitted to stderr and dropped. 3197ee25bc5SStella Laurenzo /// 3207ee25bc5SStella Laurenzo /// The unique usage of this class means that its lifetime management is 3217ee25bc5SStella Laurenzo /// different from most other parts of the API. Instances are always created 3227ee25bc5SStella Laurenzo /// in an attached state and can transition to a detached state by either: 3237ee25bc5SStella Laurenzo /// a) The context being destroyed and unregistering all handlers. 3247ee25bc5SStella Laurenzo /// b) An explicit call to detach(). 3257ee25bc5SStella Laurenzo /// The object may remain live from a Python perspective for an arbitrary time 3267ee25bc5SStella Laurenzo /// after detachment, but there is nothing the user can do with it (since there 3277ee25bc5SStella Laurenzo /// is no way to attach an existing handler object). 3287ee25bc5SStella Laurenzo class PyDiagnosticHandler { 3297ee25bc5SStella Laurenzo public: 3307ee25bc5SStella Laurenzo PyDiagnosticHandler(MlirContext context, pybind11::object callback); 3317ee25bc5SStella Laurenzo ~PyDiagnosticHandler(); 3327ee25bc5SStella Laurenzo isAttached()3337ee25bc5SStella Laurenzo bool isAttached() { return registeredID.hasValue(); } getHadError()3347ee25bc5SStella Laurenzo bool getHadError() { return hadError; } 3357ee25bc5SStella Laurenzo 3367ee25bc5SStella Laurenzo /// Detaches the handler. Does nothing if not attached. 3377ee25bc5SStella Laurenzo void detach(); 3387ee25bc5SStella Laurenzo contextEnter()3397ee25bc5SStella Laurenzo pybind11::object contextEnter() { return pybind11::cast(this); } contextExit(const pybind11::object & excType,const pybind11::object & excVal,const pybind11::object & excTb)340e8d07395SMehdi Amini void contextExit(const pybind11::object &excType, 341e8d07395SMehdi Amini const pybind11::object &excVal, 342e8d07395SMehdi Amini const pybind11::object &excTb) { 3437ee25bc5SStella Laurenzo detach(); 3447ee25bc5SStella Laurenzo } 3457ee25bc5SStella Laurenzo 3467ee25bc5SStella Laurenzo private: 3477ee25bc5SStella Laurenzo MlirContext context; 3487ee25bc5SStella Laurenzo pybind11::object callback; 3497ee25bc5SStella Laurenzo llvm::Optional<MlirDiagnosticHandlerID> registeredID; 3507ee25bc5SStella Laurenzo bool hadError = false; 3517ee25bc5SStella Laurenzo friend class PyMlirContext; 3527ee25bc5SStella Laurenzo }; 3537ee25bc5SStella Laurenzo 354436c6c9cSStella Laurenzo /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in 355436c6c9cSStella Laurenzo /// order to differentiate it from the `Dialect` base class which is extended by 356436c6c9cSStella Laurenzo /// plugins which extend dialect functionality through extension python code. 357436c6c9cSStella Laurenzo /// This should be seen as the "low-level" object and `Dialect` as the 358436c6c9cSStella Laurenzo /// high-level, user facing object. 359436c6c9cSStella Laurenzo class PyDialectDescriptor : public BaseContextObject { 360436c6c9cSStella Laurenzo public: PyDialectDescriptor(PyMlirContextRef contextRef,MlirDialect dialect)361436c6c9cSStella Laurenzo PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) 362436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), dialect(dialect) {} 363436c6c9cSStella Laurenzo get()364436c6c9cSStella Laurenzo MlirDialect get() { return dialect; } 365436c6c9cSStella Laurenzo 366436c6c9cSStella Laurenzo private: 367436c6c9cSStella Laurenzo MlirDialect dialect; 368436c6c9cSStella Laurenzo }; 369436c6c9cSStella Laurenzo 370436c6c9cSStella Laurenzo /// User-level object for accessing dialects with dotted syntax such as: 371436c6c9cSStella Laurenzo /// ctx.dialect.std 372436c6c9cSStella Laurenzo class PyDialects : public BaseContextObject { 373436c6c9cSStella Laurenzo public: PyDialects(PyMlirContextRef contextRef)374436c6c9cSStella Laurenzo PyDialects(PyMlirContextRef contextRef) 375436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)) {} 376436c6c9cSStella Laurenzo 377436c6c9cSStella Laurenzo MlirDialect getDialectForKey(const std::string &key, bool attrError); 378436c6c9cSStella Laurenzo }; 379436c6c9cSStella Laurenzo 380436c6c9cSStella Laurenzo /// User-level dialect object. For dialects that have a registered extension, 381436c6c9cSStella Laurenzo /// this will be the base class of the extension dialect type. For un-extended, 382436c6c9cSStella Laurenzo /// objects of this type will be returned directly. 383436c6c9cSStella Laurenzo class PyDialect { 384436c6c9cSStella Laurenzo public: PyDialect(pybind11::object descriptor)385436c6c9cSStella Laurenzo PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} 386436c6c9cSStella Laurenzo getDescriptor()387436c6c9cSStella Laurenzo pybind11::object getDescriptor() { return descriptor; } 388436c6c9cSStella Laurenzo 389436c6c9cSStella Laurenzo private: 390436c6c9cSStella Laurenzo pybind11::object descriptor; 391436c6c9cSStella Laurenzo }; 392436c6c9cSStella Laurenzo 393*5e83a5b4SStella Laurenzo /// Wrapper around an MlirDialectRegistry. 394*5e83a5b4SStella Laurenzo /// Upon construction, the Python wrapper takes ownership of the 395*5e83a5b4SStella Laurenzo /// underlying MlirDialectRegistry. 396*5e83a5b4SStella Laurenzo class PyDialectRegistry { 397*5e83a5b4SStella Laurenzo public: PyDialectRegistry()398*5e83a5b4SStella Laurenzo PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {} PyDialectRegistry(MlirDialectRegistry registry)399*5e83a5b4SStella Laurenzo PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {} ~PyDialectRegistry()400*5e83a5b4SStella Laurenzo ~PyDialectRegistry() { 401*5e83a5b4SStella Laurenzo if (!mlirDialectRegistryIsNull(registry)) 402*5e83a5b4SStella Laurenzo mlirDialectRegistryDestroy(registry); 403*5e83a5b4SStella Laurenzo } 404*5e83a5b4SStella Laurenzo PyDialectRegistry(PyDialectRegistry &) = delete; PyDialectRegistry(PyDialectRegistry && other)405*5e83a5b4SStella Laurenzo PyDialectRegistry(PyDialectRegistry &&other) : registry(other.registry) { 406*5e83a5b4SStella Laurenzo other.registry = {nullptr}; 407*5e83a5b4SStella Laurenzo } 408*5e83a5b4SStella Laurenzo MlirDialectRegistry()409*5e83a5b4SStella Laurenzo operator MlirDialectRegistry() const { return registry; } get()410*5e83a5b4SStella Laurenzo MlirDialectRegistry get() const { return registry; } 411*5e83a5b4SStella Laurenzo 412*5e83a5b4SStella Laurenzo pybind11::object getCapsule(); 413*5e83a5b4SStella Laurenzo static PyDialectRegistry createFromCapsule(pybind11::object capsule); 414*5e83a5b4SStella Laurenzo 415*5e83a5b4SStella Laurenzo private: 416*5e83a5b4SStella Laurenzo MlirDialectRegistry registry; 417*5e83a5b4SStella Laurenzo }; 418*5e83a5b4SStella Laurenzo 419436c6c9cSStella Laurenzo /// Wrapper around an MlirLocation. 420436c6c9cSStella Laurenzo class PyLocation : public BaseContextObject { 421436c6c9cSStella Laurenzo public: PyLocation(PyMlirContextRef contextRef,MlirLocation loc)422436c6c9cSStella Laurenzo PyLocation(PyMlirContextRef contextRef, MlirLocation loc) 423436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), loc(loc) {} 424436c6c9cSStella Laurenzo MlirLocation()425436c6c9cSStella Laurenzo operator MlirLocation() const { return loc; } get()426436c6c9cSStella Laurenzo MlirLocation get() const { return loc; } 427436c6c9cSStella Laurenzo 428436c6c9cSStella Laurenzo /// Enter and exit the context manager. 429436c6c9cSStella Laurenzo pybind11::object contextEnter(); 4301fc096afSMehdi Amini void contextExit(const pybind11::object &excType, 4311fc096afSMehdi Amini const pybind11::object &excVal, 4321fc096afSMehdi Amini const pybind11::object &excTb); 433436c6c9cSStella Laurenzo 434436c6c9cSStella Laurenzo /// Gets a capsule wrapping the void* within the MlirLocation. 435436c6c9cSStella Laurenzo pybind11::object getCapsule(); 436436c6c9cSStella Laurenzo 437436c6c9cSStella Laurenzo /// Creates a PyLocation from the MlirLocation wrapped by a capsule. 438436c6c9cSStella Laurenzo /// Note that PyLocation instances are uniqued, so the returned object 439436c6c9cSStella Laurenzo /// may be a pre-existing object. Ownership of the underlying MlirLocation 440436c6c9cSStella Laurenzo /// is taken by calling this function. 441436c6c9cSStella Laurenzo static PyLocation createFromCapsule(pybind11::object capsule); 442436c6c9cSStella Laurenzo 443436c6c9cSStella Laurenzo private: 444436c6c9cSStella Laurenzo MlirLocation loc; 445436c6c9cSStella Laurenzo }; 446436c6c9cSStella Laurenzo 447436c6c9cSStella Laurenzo /// Used in function arguments when None should resolve to the current context 448436c6c9cSStella Laurenzo /// manager set instance. 449436c6c9cSStella Laurenzo class DefaultingPyLocation 450436c6c9cSStella Laurenzo : public Defaulting<DefaultingPyLocation, PyLocation> { 451436c6c9cSStella Laurenzo public: 452436c6c9cSStella Laurenzo using Defaulting::Defaulting; 453a6e7d024SStella Laurenzo static constexpr const char kTypeDescription[] = "mlir.ir.Location"; 454436c6c9cSStella Laurenzo static PyLocation &resolve(); 455436c6c9cSStella Laurenzo MlirLocation()456436c6c9cSStella Laurenzo operator MlirLocation() const { return *get(); } 457436c6c9cSStella Laurenzo }; 458436c6c9cSStella Laurenzo 459436c6c9cSStella Laurenzo /// Wrapper around MlirModule. 460436c6c9cSStella Laurenzo /// This is the top-level, user-owned object that contains regions/ops/blocks. 461436c6c9cSStella Laurenzo class PyModule; 462436c6c9cSStella Laurenzo using PyModuleRef = PyObjectRef<PyModule>; 463436c6c9cSStella Laurenzo class PyModule : public BaseContextObject { 464436c6c9cSStella Laurenzo public: 465436c6c9cSStella Laurenzo /// Returns a PyModule reference for the given MlirModule. This may return 466436c6c9cSStella Laurenzo /// a pre-existing or new object. 467436c6c9cSStella Laurenzo static PyModuleRef forModule(MlirModule module); 468436c6c9cSStella Laurenzo PyModule(PyModule &) = delete; 469436c6c9cSStella Laurenzo PyModule(PyMlirContext &&) = delete; 470436c6c9cSStella Laurenzo ~PyModule(); 471436c6c9cSStella Laurenzo 472436c6c9cSStella Laurenzo /// Gets the backing MlirModule. get()473436c6c9cSStella Laurenzo MlirModule get() { return module; } 474436c6c9cSStella Laurenzo 475436c6c9cSStella Laurenzo /// Gets a strong reference to this module. getRef()476436c6c9cSStella Laurenzo PyModuleRef getRef() { 477436c6c9cSStella Laurenzo return PyModuleRef(this, 478436c6c9cSStella Laurenzo pybind11::reinterpret_borrow<pybind11::object>(handle)); 479436c6c9cSStella Laurenzo } 480436c6c9cSStella Laurenzo 481436c6c9cSStella Laurenzo /// Gets a capsule wrapping the void* within the MlirModule. 482436c6c9cSStella Laurenzo /// Note that the module does not (yet) provide a corresponding factory for 483436c6c9cSStella Laurenzo /// constructing from a capsule as that would require uniquing PyModule 484436c6c9cSStella Laurenzo /// instances, which is not currently done. 485436c6c9cSStella Laurenzo pybind11::object getCapsule(); 486436c6c9cSStella Laurenzo 487436c6c9cSStella Laurenzo /// Creates a PyModule from the MlirModule wrapped by a capsule. 488436c6c9cSStella Laurenzo /// Note that PyModule instances are uniqued, so the returned object 489436c6c9cSStella Laurenzo /// may be a pre-existing object. Ownership of the underlying MlirModule 490436c6c9cSStella Laurenzo /// is taken by calling this function. 491436c6c9cSStella Laurenzo static pybind11::object createFromCapsule(pybind11::object capsule); 492436c6c9cSStella Laurenzo 493436c6c9cSStella Laurenzo private: 494436c6c9cSStella Laurenzo PyModule(PyMlirContextRef contextRef, MlirModule module); 495436c6c9cSStella Laurenzo MlirModule module; 496436c6c9cSStella Laurenzo pybind11::handle handle; 497436c6c9cSStella Laurenzo }; 498436c6c9cSStella Laurenzo 499436c6c9cSStella Laurenzo /// Base class for PyOperation and PyOpView which exposes the primary, user 500436c6c9cSStella Laurenzo /// visible methods for manipulating it. 501436c6c9cSStella Laurenzo class PyOperationBase { 502436c6c9cSStella Laurenzo public: 503436c6c9cSStella Laurenzo virtual ~PyOperationBase() = default; 504436c6c9cSStella Laurenzo /// Implements the bound 'print' method and helps with others. 505436c6c9cSStella Laurenzo void print(pybind11::object fileObject, bool binary, 506436c6c9cSStella Laurenzo llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo, 507ace1d0adSStella Laurenzo bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, 508ace1d0adSStella Laurenzo bool assumeVerified); 509436c6c9cSStella Laurenzo pybind11::object getAsm(bool binary, 510436c6c9cSStella Laurenzo llvm::Optional<int64_t> largeElementsLimit, 511436c6c9cSStella Laurenzo bool enableDebugInfo, bool prettyDebugInfo, 512ace1d0adSStella Laurenzo bool printGenericOpForm, bool useLocalScope, 513ace1d0adSStella Laurenzo bool assumeVerified); 514436c6c9cSStella Laurenzo 51524685aaeSAlex Zinenko /// Moves the operation before or after the other operation. 51624685aaeSAlex Zinenko void moveAfter(PyOperationBase &other); 51724685aaeSAlex Zinenko void moveBefore(PyOperationBase &other); 51824685aaeSAlex Zinenko 519436c6c9cSStella Laurenzo /// Each must provide access to the raw Operation. 520436c6c9cSStella Laurenzo virtual PyOperation &getOperation() = 0; 521436c6c9cSStella Laurenzo }; 522436c6c9cSStella Laurenzo 523436c6c9cSStella Laurenzo /// Wrapper around PyOperation. 524436c6c9cSStella Laurenzo /// Operations exist in either an attached (dependent) or detached (top-level) 525436c6c9cSStella Laurenzo /// state. In the detached state (as on creation), an operation is owned by 526436c6c9cSStella Laurenzo /// the creator and its lifetime extends either until its reference count 527436c6c9cSStella Laurenzo /// drops to zero or it is attached to a parent, at which point its lifetime 528436c6c9cSStella Laurenzo /// is bounded by its top-level parent reference. 529436c6c9cSStella Laurenzo class PyOperation; 530436c6c9cSStella Laurenzo using PyOperationRef = PyObjectRef<PyOperation>; 531436c6c9cSStella Laurenzo class PyOperation : public PyOperationBase, public BaseContextObject { 532436c6c9cSStella Laurenzo public: 533bd87241cSMehdi Amini ~PyOperation() override; getOperation()534436c6c9cSStella Laurenzo PyOperation &getOperation() override { return *this; } 535436c6c9cSStella Laurenzo 536436c6c9cSStella Laurenzo /// Returns a PyOperation for the given MlirOperation, optionally associating 537436c6c9cSStella Laurenzo /// it with a parentKeepAlive. 538436c6c9cSStella Laurenzo static PyOperationRef 539436c6c9cSStella Laurenzo forOperation(PyMlirContextRef contextRef, MlirOperation operation, 540436c6c9cSStella Laurenzo pybind11::object parentKeepAlive = pybind11::object()); 541436c6c9cSStella Laurenzo 542436c6c9cSStella Laurenzo /// Creates a detached operation. The operation must not be associated with 543436c6c9cSStella Laurenzo /// any existing live operation. 544436c6c9cSStella Laurenzo static PyOperationRef 545436c6c9cSStella Laurenzo createDetached(PyMlirContextRef contextRef, MlirOperation operation, 546436c6c9cSStella Laurenzo pybind11::object parentKeepAlive = pybind11::object()); 547436c6c9cSStella Laurenzo 54824685aaeSAlex Zinenko /// Detaches the operation from its parent block and updates its state 54924685aaeSAlex Zinenko /// accordingly. detachFromParent()55024685aaeSAlex Zinenko void detachFromParent() { 55124685aaeSAlex Zinenko mlirOperationRemoveFromParent(getOperation()); 55224685aaeSAlex Zinenko setDetached(); 55324685aaeSAlex Zinenko parentKeepAlive = pybind11::object(); 55424685aaeSAlex Zinenko } 55524685aaeSAlex Zinenko 556436c6c9cSStella Laurenzo /// Gets the backing operation. MlirOperation()557436c6c9cSStella Laurenzo operator MlirOperation() const { return get(); } get()558436c6c9cSStella Laurenzo MlirOperation get() const { 559436c6c9cSStella Laurenzo checkValid(); 560436c6c9cSStella Laurenzo return operation; 561436c6c9cSStella Laurenzo } 562436c6c9cSStella Laurenzo getRef()563436c6c9cSStella Laurenzo PyOperationRef getRef() { 564436c6c9cSStella Laurenzo return PyOperationRef( 565436c6c9cSStella Laurenzo this, pybind11::reinterpret_borrow<pybind11::object>(handle)); 566436c6c9cSStella Laurenzo } 567436c6c9cSStella Laurenzo isAttached()568436c6c9cSStella Laurenzo bool isAttached() { return attached; } 569e8d07395SMehdi Amini void setAttached(const pybind11::object &parent = pybind11::object()) { 570436c6c9cSStella Laurenzo assert(!attached && "operation already attached"); 571436c6c9cSStella Laurenzo attached = true; 572436c6c9cSStella Laurenzo } setDetached()57324685aaeSAlex Zinenko void setDetached() { 57424685aaeSAlex Zinenko assert(attached && "operation already detached"); 57524685aaeSAlex Zinenko attached = false; 57624685aaeSAlex Zinenko } 577436c6c9cSStella Laurenzo void checkValid() const; 578436c6c9cSStella Laurenzo 579436c6c9cSStella Laurenzo /// Gets the owning block or raises an exception if the operation has no 580436c6c9cSStella Laurenzo /// owning block. 581436c6c9cSStella Laurenzo PyBlock getBlock(); 582436c6c9cSStella Laurenzo 583436c6c9cSStella Laurenzo /// Gets the parent operation or raises an exception if the operation has 584436c6c9cSStella Laurenzo /// no parent. 5851689dadeSJohn Demme llvm::Optional<PyOperationRef> getParentOperation(); 586436c6c9cSStella Laurenzo 5870126e906SJohn Demme /// Gets a capsule wrapping the void* within the MlirOperation. 5880126e906SJohn Demme pybind11::object getCapsule(); 5890126e906SJohn Demme 5900126e906SJohn Demme /// Creates a PyOperation from the MlirOperation wrapped by a capsule. 5910126e906SJohn Demme /// Ownership of the underlying MlirOperation is taken by calling this 5920126e906SJohn Demme /// function. 5930126e906SJohn Demme static pybind11::object createFromCapsule(pybind11::object capsule); 5940126e906SJohn Demme 595436c6c9cSStella Laurenzo /// Creates an operation. See corresponding python docstring. 596436c6c9cSStella Laurenzo static pybind11::object 5971fc096afSMehdi Amini create(const std::string &name, llvm::Optional<std::vector<PyType *>> results, 598436c6c9cSStella Laurenzo llvm::Optional<std::vector<PyValue *>> operands, 599436c6c9cSStella Laurenzo llvm::Optional<pybind11::dict> attributes, 600436c6c9cSStella Laurenzo llvm::Optional<std::vector<PyBlock *>> successors, int regions, 6011fc096afSMehdi Amini DefaultingPyLocation location, const pybind11::object &ip); 602436c6c9cSStella Laurenzo 603436c6c9cSStella Laurenzo /// Creates an OpView suitable for this operation. 604436c6c9cSStella Laurenzo pybind11::object createOpView(); 605436c6c9cSStella Laurenzo 60649745f87SMike Urbach /// Erases the underlying MlirOperation, removes its pointer from the 60749745f87SMike Urbach /// parent context's live operations map, and sets the valid bit false. 60849745f87SMike Urbach void erase(); 60949745f87SMike Urbach 6106b0bed7eSJohn Demme /// Invalidate the operation. setInvalid()6116b0bed7eSJohn Demme void setInvalid() { valid = false; } 6126b0bed7eSJohn Demme 613774818c0SDominik Grewe /// Clones this operation. 614774818c0SDominik Grewe pybind11::object clone(const pybind11::object &ip); 615774818c0SDominik Grewe 616436c6c9cSStella Laurenzo private: 617436c6c9cSStella Laurenzo PyOperation(PyMlirContextRef contextRef, MlirOperation operation); 618436c6c9cSStella Laurenzo static PyOperationRef createInstance(PyMlirContextRef contextRef, 619436c6c9cSStella Laurenzo MlirOperation operation, 620436c6c9cSStella Laurenzo pybind11::object parentKeepAlive); 621436c6c9cSStella Laurenzo 622436c6c9cSStella Laurenzo MlirOperation operation; 623436c6c9cSStella Laurenzo pybind11::handle handle; 624436c6c9cSStella Laurenzo // Keeps the parent alive, regardless of whether it is an Operation or 625436c6c9cSStella Laurenzo // Module. 626436c6c9cSStella Laurenzo // TODO: As implemented, this facility is only sufficient for modeling the 627436c6c9cSStella Laurenzo // trivial module parent back-reference. Generalize this to also account for 628436c6c9cSStella Laurenzo // transitions from detached to attached and address TODOs in the 629436c6c9cSStella Laurenzo // ir_operation.py regarding testing corresponding lifetime guarantees. 630436c6c9cSStella Laurenzo pybind11::object parentKeepAlive; 631436c6c9cSStella Laurenzo bool attached = true; 632436c6c9cSStella Laurenzo bool valid = true; 63324685aaeSAlex Zinenko 63424685aaeSAlex Zinenko friend class PyOperationBase; 63530d61893SAlex Zinenko friend class PySymbolTable; 636436c6c9cSStella Laurenzo }; 637436c6c9cSStella Laurenzo 638436c6c9cSStella Laurenzo /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for 639436c6c9cSStella Laurenzo /// providing more instance-specific accessors and serve as the base class for 640436c6c9cSStella Laurenzo /// custom ODS-style operation classes. Since this class is subclass on the 641436c6c9cSStella Laurenzo /// python side, it must present an __init__ method that operates in pure 642436c6c9cSStella Laurenzo /// python types. 643436c6c9cSStella Laurenzo class PyOpView : public PyOperationBase { 644436c6c9cSStella Laurenzo public: 6451fc096afSMehdi Amini PyOpView(const pybind11::object &operationObject); getOperation()646436c6c9cSStella Laurenzo PyOperation &getOperation() override { return operation; } 647436c6c9cSStella Laurenzo 6481fc096afSMehdi Amini static pybind11::object createRawSubclass(const pybind11::object &userClass); 649436c6c9cSStella Laurenzo getOperationObject()650436c6c9cSStella Laurenzo pybind11::object getOperationObject() { return operationObject; } 651436c6c9cSStella Laurenzo 652436c6c9cSStella Laurenzo static pybind11::object 6531fc096afSMehdi Amini buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList, 654436c6c9cSStella Laurenzo pybind11::list operandList, 655436c6c9cSStella Laurenzo llvm::Optional<pybind11::dict> attributes, 656436c6c9cSStella Laurenzo llvm::Optional<std::vector<PyBlock *>> successors, 657436c6c9cSStella Laurenzo llvm::Optional<int> regions, DefaultingPyLocation location, 6584f415216SMehdi Amini const pybind11::object &maybeIp); 659436c6c9cSStella Laurenzo 660436c6c9cSStella Laurenzo private: 661436c6c9cSStella Laurenzo PyOperation &operation; // For efficient, cast-free access from C++ 662436c6c9cSStella Laurenzo pybind11::object operationObject; // Holds the reference. 663436c6c9cSStella Laurenzo }; 664436c6c9cSStella Laurenzo 665436c6c9cSStella Laurenzo /// Wrapper around an MlirRegion. 666436c6c9cSStella Laurenzo /// Regions are managed completely by their containing operation. Unlike the 667436c6c9cSStella Laurenzo /// C++ API, the python API does not support detached regions. 668436c6c9cSStella Laurenzo class PyRegion { 669436c6c9cSStella Laurenzo public: PyRegion(PyOperationRef parentOperation,MlirRegion region)670436c6c9cSStella Laurenzo PyRegion(PyOperationRef parentOperation, MlirRegion region) 671436c6c9cSStella Laurenzo : parentOperation(std::move(parentOperation)), region(region) { 672436c6c9cSStella Laurenzo assert(!mlirRegionIsNull(region) && "python region cannot be null"); 673436c6c9cSStella Laurenzo } MlirRegion()67478f2dae0SAlex Zinenko operator MlirRegion() const { return region; } 675436c6c9cSStella Laurenzo get()676436c6c9cSStella Laurenzo MlirRegion get() { return region; } getParentOperation()677436c6c9cSStella Laurenzo PyOperationRef &getParentOperation() { return parentOperation; } 678436c6c9cSStella Laurenzo checkValid()679436c6c9cSStella Laurenzo void checkValid() { return parentOperation->checkValid(); } 680436c6c9cSStella Laurenzo 681436c6c9cSStella Laurenzo private: 682436c6c9cSStella Laurenzo PyOperationRef parentOperation; 683436c6c9cSStella Laurenzo MlirRegion region; 684436c6c9cSStella Laurenzo }; 685436c6c9cSStella Laurenzo 686436c6c9cSStella Laurenzo /// Wrapper around an MlirBlock. 687436c6c9cSStella Laurenzo /// Blocks are managed completely by their containing operation. Unlike the 688436c6c9cSStella Laurenzo /// C++ API, the python API does not support detached blocks. 689436c6c9cSStella Laurenzo class PyBlock { 690436c6c9cSStella Laurenzo public: PyBlock(PyOperationRef parentOperation,MlirBlock block)691436c6c9cSStella Laurenzo PyBlock(PyOperationRef parentOperation, MlirBlock block) 692436c6c9cSStella Laurenzo : parentOperation(std::move(parentOperation)), block(block) { 693436c6c9cSStella Laurenzo assert(!mlirBlockIsNull(block) && "python block cannot be null"); 694436c6c9cSStella Laurenzo } 695436c6c9cSStella Laurenzo get()696436c6c9cSStella Laurenzo MlirBlock get() { return block; } getParentOperation()697436c6c9cSStella Laurenzo PyOperationRef &getParentOperation() { return parentOperation; } 698436c6c9cSStella Laurenzo checkValid()699436c6c9cSStella Laurenzo void checkValid() { return parentOperation->checkValid(); } 700436c6c9cSStella Laurenzo 701436c6c9cSStella Laurenzo private: 702436c6c9cSStella Laurenzo PyOperationRef parentOperation; 703436c6c9cSStella Laurenzo MlirBlock block; 704436c6c9cSStella Laurenzo }; 705436c6c9cSStella Laurenzo 706436c6c9cSStella Laurenzo /// An insertion point maintains a pointer to a Block and a reference operation. 707436c6c9cSStella Laurenzo /// Calls to insert() will insert a new operation before the 708436c6c9cSStella Laurenzo /// reference operation. If the reference operation is null, then appends to 709436c6c9cSStella Laurenzo /// the end of the block. 710436c6c9cSStella Laurenzo class PyInsertionPoint { 711436c6c9cSStella Laurenzo public: 712436c6c9cSStella Laurenzo /// Creates an insertion point positioned after the last operation in the 713436c6c9cSStella Laurenzo /// block, but still inside the block. 714436c6c9cSStella Laurenzo PyInsertionPoint(PyBlock &block); 715436c6c9cSStella Laurenzo /// Creates an insertion point positioned before a reference operation. 716436c6c9cSStella Laurenzo PyInsertionPoint(PyOperationBase &beforeOperationBase); 717436c6c9cSStella Laurenzo 718436c6c9cSStella Laurenzo /// Shortcut to create an insertion point at the beginning of the block. 719436c6c9cSStella Laurenzo static PyInsertionPoint atBlockBegin(PyBlock &block); 720436c6c9cSStella Laurenzo /// Shortcut to create an insertion point before the block terminator. 721436c6c9cSStella Laurenzo static PyInsertionPoint atBlockTerminator(PyBlock &block); 722436c6c9cSStella Laurenzo 723436c6c9cSStella Laurenzo /// Inserts an operation. 724436c6c9cSStella Laurenzo void insert(PyOperationBase &operationBase); 725436c6c9cSStella Laurenzo 726436c6c9cSStella Laurenzo /// Enter and exit the context manager. 727436c6c9cSStella Laurenzo pybind11::object contextEnter(); 7281fc096afSMehdi Amini void contextExit(const pybind11::object &excType, 7291fc096afSMehdi Amini const pybind11::object &excVal, 7301fc096afSMehdi Amini const pybind11::object &excTb); 731436c6c9cSStella Laurenzo getBlock()732436c6c9cSStella Laurenzo PyBlock &getBlock() { return block; } 733436c6c9cSStella Laurenzo 734436c6c9cSStella Laurenzo private: 735436c6c9cSStella Laurenzo // Trampoline constructor that avoids null initializing members while 736436c6c9cSStella Laurenzo // looking up parents. PyInsertionPoint(PyBlock block,llvm::Optional<PyOperationRef> refOperation)737436c6c9cSStella Laurenzo PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation) 738436c6c9cSStella Laurenzo : refOperation(std::move(refOperation)), block(std::move(block)) {} 739436c6c9cSStella Laurenzo 740436c6c9cSStella Laurenzo llvm::Optional<PyOperationRef> refOperation; 741436c6c9cSStella Laurenzo PyBlock block; 742436c6c9cSStella Laurenzo }; 7432995d29bSAlex Zinenko /// Wrapper around the generic MlirType. 7442995d29bSAlex Zinenko /// The lifetime of a type is bound by the PyContext that created it. 7452995d29bSAlex Zinenko class PyType : public BaseContextObject { 7462995d29bSAlex Zinenko public: PyType(PyMlirContextRef contextRef,MlirType type)7472995d29bSAlex Zinenko PyType(PyMlirContextRef contextRef, MlirType type) 7482995d29bSAlex Zinenko : BaseContextObject(std::move(contextRef)), type(type) {} 7492995d29bSAlex Zinenko bool operator==(const PyType &other); MlirType()7502995d29bSAlex Zinenko operator MlirType() const { return type; } get()7512995d29bSAlex Zinenko MlirType get() const { return type; } 7522995d29bSAlex Zinenko 7532995d29bSAlex Zinenko /// Gets a capsule wrapping the void* within the MlirType. 7542995d29bSAlex Zinenko pybind11::object getCapsule(); 7552995d29bSAlex Zinenko 7562995d29bSAlex Zinenko /// Creates a PyType from the MlirType wrapped by a capsule. 7572995d29bSAlex Zinenko /// Note that PyType instances are uniqued, so the returned object 7582995d29bSAlex Zinenko /// may be a pre-existing object. Ownership of the underlying MlirType 7592995d29bSAlex Zinenko /// is taken by calling this function. 7602995d29bSAlex Zinenko static PyType createFromCapsule(pybind11::object capsule); 7612995d29bSAlex Zinenko 7622995d29bSAlex Zinenko private: 7632995d29bSAlex Zinenko MlirType type; 7642995d29bSAlex Zinenko }; 7652995d29bSAlex Zinenko 7662995d29bSAlex Zinenko /// CRTP base classes for Python types that subclass Type and should be 7672995d29bSAlex Zinenko /// castable from it (i.e. via something like IntegerType(t)). 7682995d29bSAlex Zinenko /// By default, type class hierarchies are one level deep (i.e. a 7692995d29bSAlex Zinenko /// concrete type class extends PyType); however, intermediate python-visible 7702995d29bSAlex Zinenko /// base classes can be modeled by specifying a BaseTy. 7712995d29bSAlex Zinenko template <typename DerivedTy, typename BaseTy = PyType> 7722995d29bSAlex Zinenko class PyConcreteType : public BaseTy { 7732995d29bSAlex Zinenko public: 7742995d29bSAlex Zinenko // Derived classes must define statics for: 7752995d29bSAlex Zinenko // IsAFunctionTy isaFunction 7762995d29bSAlex Zinenko // const char *pyClassName 7772995d29bSAlex Zinenko using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 7782995d29bSAlex Zinenko using IsAFunctionTy = bool (*)(MlirType); 7792995d29bSAlex Zinenko 7802995d29bSAlex Zinenko PyConcreteType() = default; PyConcreteType(PyMlirContextRef contextRef,MlirType t)7812995d29bSAlex Zinenko PyConcreteType(PyMlirContextRef contextRef, MlirType t) 7822995d29bSAlex Zinenko : BaseTy(std::move(contextRef), t) {} PyConcreteType(PyType & orig)7832995d29bSAlex Zinenko PyConcreteType(PyType &orig) 7842995d29bSAlex Zinenko : PyConcreteType(orig.getContext(), castFrom(orig)) {} 7852995d29bSAlex Zinenko castFrom(PyType & orig)7862995d29bSAlex Zinenko static MlirType castFrom(PyType &orig) { 7872995d29bSAlex Zinenko if (!DerivedTy::isaFunction(orig)) { 7882995d29bSAlex Zinenko auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 7892995d29bSAlex Zinenko throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + 7902995d29bSAlex Zinenko DerivedTy::pyClassName + 7912995d29bSAlex Zinenko " (from " + origRepr + ")"); 7922995d29bSAlex Zinenko } 7932995d29bSAlex Zinenko return orig; 7942995d29bSAlex Zinenko } 7952995d29bSAlex Zinenko bind(pybind11::module & m)7962995d29bSAlex Zinenko static void bind(pybind11::module &m) { 7972995d29bSAlex Zinenko auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); 798a6e7d024SStella Laurenzo cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(), 799a6e7d024SStella Laurenzo pybind11::arg("cast_from_type")); 800a6e7d024SStella Laurenzo cls.def_static( 801a6e7d024SStella Laurenzo "isinstance", 802a6e7d024SStella Laurenzo [](PyType &otherType) -> bool { 8032995d29bSAlex Zinenko return DerivedTy::isaFunction(otherType); 804a6e7d024SStella Laurenzo }, 805a6e7d024SStella Laurenzo pybind11::arg("other")); 8062995d29bSAlex Zinenko DerivedTy::bindDerived(cls); 8072995d29bSAlex Zinenko } 8082995d29bSAlex Zinenko 8092995d29bSAlex Zinenko /// Implemented by derived classes to add methods to the Python subclass. bindDerived(ClassTy & m)8102995d29bSAlex Zinenko static void bindDerived(ClassTy &m) {} 8112995d29bSAlex Zinenko }; 812436c6c9cSStella Laurenzo 813436c6c9cSStella Laurenzo /// Wrapper around the generic MlirAttribute. 814436c6c9cSStella Laurenzo /// The lifetime of a type is bound by the PyContext that created it. 815436c6c9cSStella Laurenzo class PyAttribute : public BaseContextObject { 816436c6c9cSStella Laurenzo public: PyAttribute(PyMlirContextRef contextRef,MlirAttribute attr)817436c6c9cSStella Laurenzo PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 818436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), attr(attr) {} 819436c6c9cSStella Laurenzo bool operator==(const PyAttribute &other); MlirAttribute()820436c6c9cSStella Laurenzo operator MlirAttribute() const { return attr; } get()821436c6c9cSStella Laurenzo MlirAttribute get() const { return attr; } 822436c6c9cSStella Laurenzo 823436c6c9cSStella Laurenzo /// Gets a capsule wrapping the void* within the MlirAttribute. 824436c6c9cSStella Laurenzo pybind11::object getCapsule(); 825436c6c9cSStella Laurenzo 826436c6c9cSStella Laurenzo /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. 827436c6c9cSStella Laurenzo /// Note that PyAttribute instances are uniqued, so the returned object 828436c6c9cSStella Laurenzo /// may be a pre-existing object. Ownership of the underlying MlirAttribute 829436c6c9cSStella Laurenzo /// is taken by calling this function. 830436c6c9cSStella Laurenzo static PyAttribute createFromCapsule(pybind11::object capsule); 831436c6c9cSStella Laurenzo 832436c6c9cSStella Laurenzo private: 833436c6c9cSStella Laurenzo MlirAttribute attr; 834436c6c9cSStella Laurenzo }; 835436c6c9cSStella Laurenzo 836436c6c9cSStella Laurenzo /// Represents a Python MlirNamedAttr, carrying an optional owned name. 837436c6c9cSStella Laurenzo /// TODO: Refactor this and the C-API to be based on an Identifier owned 838436c6c9cSStella Laurenzo /// by the context so as to avoid ownership issues here. 839436c6c9cSStella Laurenzo class PyNamedAttribute { 840436c6c9cSStella Laurenzo public: 841436c6c9cSStella Laurenzo /// Constructs a PyNamedAttr that retains an owned name. This should be 842436c6c9cSStella Laurenzo /// used in any code that originates an MlirNamedAttribute from a python 843436c6c9cSStella Laurenzo /// string. 844436c6c9cSStella Laurenzo /// The lifetime of the PyNamedAttr must extend to the lifetime of the 845436c6c9cSStella Laurenzo /// passed attribute. 846436c6c9cSStella Laurenzo PyNamedAttribute(MlirAttribute attr, std::string ownedName); 847436c6c9cSStella Laurenzo 848436c6c9cSStella Laurenzo MlirNamedAttribute namedAttr; 849436c6c9cSStella Laurenzo 850436c6c9cSStella Laurenzo private: 851436c6c9cSStella Laurenzo // Since the MlirNamedAttr contains an internal pointer to the actual 852436c6c9cSStella Laurenzo // memory of the owned string, it must be heap allocated to remain valid. 853436c6c9cSStella Laurenzo // Otherwise, strings that fit within the small object optimization threshold 854436c6c9cSStella Laurenzo // will have their memory address change as the containing object is moved, 855436c6c9cSStella Laurenzo // resulting in an invalid aliased pointer. 856436c6c9cSStella Laurenzo std::unique_ptr<std::string> ownedName; 857436c6c9cSStella Laurenzo }; 858436c6c9cSStella Laurenzo 8590b10fdedSAlex Zinenko /// CRTP base classes for Python attributes that subclass Attribute and should 8600b10fdedSAlex Zinenko /// be castable from it (i.e. via something like StringAttr(attr)). 8610b10fdedSAlex Zinenko /// By default, attribute class hierarchies are one level deep (i.e. a 8620b10fdedSAlex Zinenko /// concrete attribute class extends PyAttribute); however, intermediate 8630b10fdedSAlex Zinenko /// python-visible base classes can be modeled by specifying a BaseTy. 8640b10fdedSAlex Zinenko template <typename DerivedTy, typename BaseTy = PyAttribute> 8650b10fdedSAlex Zinenko class PyConcreteAttribute : public BaseTy { 8660b10fdedSAlex Zinenko public: 8670b10fdedSAlex Zinenko // Derived classes must define statics for: 8680b10fdedSAlex Zinenko // IsAFunctionTy isaFunction 8690b10fdedSAlex Zinenko // const char *pyClassName 8700b10fdedSAlex Zinenko using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 8710b10fdedSAlex Zinenko using IsAFunctionTy = bool (*)(MlirAttribute); 8720b10fdedSAlex Zinenko 8730b10fdedSAlex Zinenko PyConcreteAttribute() = default; PyConcreteAttribute(PyMlirContextRef contextRef,MlirAttribute attr)8740b10fdedSAlex Zinenko PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 8750b10fdedSAlex Zinenko : BaseTy(std::move(contextRef), attr) {} PyConcreteAttribute(PyAttribute & orig)8760b10fdedSAlex Zinenko PyConcreteAttribute(PyAttribute &orig) 8770b10fdedSAlex Zinenko : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} 8780b10fdedSAlex Zinenko castFrom(PyAttribute & orig)8790b10fdedSAlex Zinenko static MlirAttribute castFrom(PyAttribute &orig) { 8800b10fdedSAlex Zinenko if (!DerivedTy::isaFunction(orig)) { 8810b10fdedSAlex Zinenko auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 8820b10fdedSAlex Zinenko throw SetPyError(PyExc_ValueError, 8830b10fdedSAlex Zinenko llvm::Twine("Cannot cast attribute to ") + 8840b10fdedSAlex Zinenko DerivedTy::pyClassName + " (from " + origRepr + ")"); 8850b10fdedSAlex Zinenko } 8860b10fdedSAlex Zinenko return orig; 8870b10fdedSAlex Zinenko } 8880b10fdedSAlex Zinenko bind(pybind11::module & m)8890b10fdedSAlex Zinenko static void bind(pybind11::module &m) { 8908dca953dSSean Silva auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), 8918dca953dSSean Silva pybind11::module_local()); 892a6e7d024SStella Laurenzo cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(), 893a6e7d024SStella Laurenzo pybind11::arg("cast_from_attr")); 894a6e7d024SStella Laurenzo cls.def_static( 895a6e7d024SStella Laurenzo "isinstance", 896a6e7d024SStella Laurenzo [](PyAttribute &otherAttr) -> bool { 89778f2dae0SAlex Zinenko return DerivedTy::isaFunction(otherAttr); 898a6e7d024SStella Laurenzo }, 899a6e7d024SStella Laurenzo pybind11::arg("other")); 9002995d29bSAlex Zinenko cls.def_property_readonly("type", [](PyAttribute &attr) { 9012995d29bSAlex Zinenko return PyType(attr.getContext(), mlirAttributeGetType(attr)); 90232e2fec7SJohn Demme }); 90332e2fec7SJohn Demme DerivedTy::bindDerived(cls); 90432e2fec7SJohn Demme } 90532e2fec7SJohn Demme 90632e2fec7SJohn Demme /// Implemented by derived classes to add methods to the Python subclass. bindDerived(ClassTy & m)90732e2fec7SJohn Demme static void bindDerived(ClassTy &m) {} 90832e2fec7SJohn Demme }; 90932e2fec7SJohn Demme 910436c6c9cSStella Laurenzo /// Wrapper around the generic MlirValue. 911436c6c9cSStella Laurenzo /// Values are managed completely by the operation that resulted in their 912436c6c9cSStella Laurenzo /// definition. For op result value, this is the operation that defines the 913436c6c9cSStella Laurenzo /// value. For block argument values, this is the operation that contains the 914436c6c9cSStella Laurenzo /// block to which the value is an argument (blocks cannot be detached in Python 915436c6c9cSStella Laurenzo /// bindings so such operation always exists). 916436c6c9cSStella Laurenzo class PyValue { 917436c6c9cSStella Laurenzo public: PyValue(PyOperationRef parentOperation,MlirValue value)918436c6c9cSStella Laurenzo PyValue(PyOperationRef parentOperation, MlirValue value) 919e8d07395SMehdi Amini : parentOperation(std::move(parentOperation)), value(value) {} MlirValue()92078f2dae0SAlex Zinenko operator MlirValue() const { return value; } 921436c6c9cSStella Laurenzo get()922436c6c9cSStella Laurenzo MlirValue get() { return value; } getParentOperation()923436c6c9cSStella Laurenzo PyOperationRef &getParentOperation() { return parentOperation; } 924436c6c9cSStella Laurenzo checkValid()925436c6c9cSStella Laurenzo void checkValid() { return parentOperation->checkValid(); } 926436c6c9cSStella Laurenzo 9273f3d1c90SMike Urbach /// Gets a capsule wrapping the void* within the MlirValue. 9283f3d1c90SMike Urbach pybind11::object getCapsule(); 9293f3d1c90SMike Urbach 9303f3d1c90SMike Urbach /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of 9313f3d1c90SMike Urbach /// the underlying MlirValue is still tied to the owning operation. 9323f3d1c90SMike Urbach static PyValue createFromCapsule(pybind11::object capsule); 9333f3d1c90SMike Urbach 934436c6c9cSStella Laurenzo private: 935436c6c9cSStella Laurenzo PyOperationRef parentOperation; 936436c6c9cSStella Laurenzo MlirValue value; 937436c6c9cSStella Laurenzo }; 938436c6c9cSStella Laurenzo 939436c6c9cSStella Laurenzo /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. 940436c6c9cSStella Laurenzo class PyAffineExpr : public BaseContextObject { 941436c6c9cSStella Laurenzo public: PyAffineExpr(PyMlirContextRef contextRef,MlirAffineExpr affineExpr)942436c6c9cSStella Laurenzo PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 943436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} 944436c6c9cSStella Laurenzo bool operator==(const PyAffineExpr &other); MlirAffineExpr()945436c6c9cSStella Laurenzo operator MlirAffineExpr() const { return affineExpr; } get()946436c6c9cSStella Laurenzo MlirAffineExpr get() const { return affineExpr; } 947436c6c9cSStella Laurenzo 948436c6c9cSStella Laurenzo /// Gets a capsule wrapping the void* within the MlirAffineExpr. 949436c6c9cSStella Laurenzo pybind11::object getCapsule(); 950436c6c9cSStella Laurenzo 951436c6c9cSStella Laurenzo /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. 952436c6c9cSStella Laurenzo /// Note that PyAffineExpr instances are uniqued, so the returned object 953436c6c9cSStella Laurenzo /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr 954436c6c9cSStella Laurenzo /// is taken by calling this function. 955436c6c9cSStella Laurenzo static PyAffineExpr createFromCapsule(pybind11::object capsule); 956436c6c9cSStella Laurenzo 957436c6c9cSStella Laurenzo PyAffineExpr add(const PyAffineExpr &other) const; 958436c6c9cSStella Laurenzo PyAffineExpr mul(const PyAffineExpr &other) const; 959436c6c9cSStella Laurenzo PyAffineExpr floorDiv(const PyAffineExpr &other) const; 960436c6c9cSStella Laurenzo PyAffineExpr ceilDiv(const PyAffineExpr &other) const; 961436c6c9cSStella Laurenzo PyAffineExpr mod(const PyAffineExpr &other) const; 962436c6c9cSStella Laurenzo 963436c6c9cSStella Laurenzo private: 964436c6c9cSStella Laurenzo MlirAffineExpr affineExpr; 965436c6c9cSStella Laurenzo }; 966436c6c9cSStella Laurenzo 967436c6c9cSStella Laurenzo class PyAffineMap : public BaseContextObject { 968436c6c9cSStella Laurenzo public: PyAffineMap(PyMlirContextRef contextRef,MlirAffineMap affineMap)969436c6c9cSStella Laurenzo PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) 970436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} 971436c6c9cSStella Laurenzo bool operator==(const PyAffineMap &other); MlirAffineMap()972436c6c9cSStella Laurenzo operator MlirAffineMap() const { return affineMap; } get()973436c6c9cSStella Laurenzo MlirAffineMap get() const { return affineMap; } 974436c6c9cSStella Laurenzo 975436c6c9cSStella Laurenzo /// Gets a capsule wrapping the void* within the MlirAffineMap. 976436c6c9cSStella Laurenzo pybind11::object getCapsule(); 977436c6c9cSStella Laurenzo 978436c6c9cSStella Laurenzo /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. 979436c6c9cSStella Laurenzo /// Note that PyAffineMap instances are uniqued, so the returned object 980436c6c9cSStella Laurenzo /// may be a pre-existing object. Ownership of the underlying MlirAffineMap 981436c6c9cSStella Laurenzo /// is taken by calling this function. 982436c6c9cSStella Laurenzo static PyAffineMap createFromCapsule(pybind11::object capsule); 983436c6c9cSStella Laurenzo 984436c6c9cSStella Laurenzo private: 985436c6c9cSStella Laurenzo MlirAffineMap affineMap; 986436c6c9cSStella Laurenzo }; 987436c6c9cSStella Laurenzo 988436c6c9cSStella Laurenzo class PyIntegerSet : public BaseContextObject { 989436c6c9cSStella Laurenzo public: PyIntegerSet(PyMlirContextRef contextRef,MlirIntegerSet integerSet)990436c6c9cSStella Laurenzo PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) 991436c6c9cSStella Laurenzo : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} 992436c6c9cSStella Laurenzo bool operator==(const PyIntegerSet &other); MlirIntegerSet()993436c6c9cSStella Laurenzo operator MlirIntegerSet() const { return integerSet; } get()994436c6c9cSStella Laurenzo MlirIntegerSet get() const { return integerSet; } 995436c6c9cSStella Laurenzo 996436c6c9cSStella Laurenzo /// Gets a capsule wrapping the void* within the MlirIntegerSet. 997436c6c9cSStella Laurenzo pybind11::object getCapsule(); 998436c6c9cSStella Laurenzo 999436c6c9cSStella Laurenzo /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. 1000436c6c9cSStella Laurenzo /// Note that PyIntegerSet instances may be uniqued, so the returned object 1001436c6c9cSStella Laurenzo /// may be a pre-existing object. Integer sets are owned by the context. 1002436c6c9cSStella Laurenzo static PyIntegerSet createFromCapsule(pybind11::object capsule); 1003436c6c9cSStella Laurenzo 1004436c6c9cSStella Laurenzo private: 1005436c6c9cSStella Laurenzo MlirIntegerSet integerSet; 1006436c6c9cSStella Laurenzo }; 1007436c6c9cSStella Laurenzo 100830d61893SAlex Zinenko /// Bindings for MLIR symbol tables. 100930d61893SAlex Zinenko class PySymbolTable { 101030d61893SAlex Zinenko public: 101130d61893SAlex Zinenko /// Constructs a symbol table for the given operation. 101230d61893SAlex Zinenko explicit PySymbolTable(PyOperationBase &operation); 101330d61893SAlex Zinenko 101430d61893SAlex Zinenko /// Destroys the symbol table. ~PySymbolTable()101530d61893SAlex Zinenko ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); } 101630d61893SAlex Zinenko 101730d61893SAlex Zinenko /// Returns the symbol (opview) with the given name, throws if there is no 101830d61893SAlex Zinenko /// such symbol in the table. 101930d61893SAlex Zinenko pybind11::object dunderGetItem(const std::string &name); 102030d61893SAlex Zinenko 102130d61893SAlex Zinenko /// Removes the given operation from the symbol table and erases it. 102230d61893SAlex Zinenko void erase(PyOperationBase &symbol); 102330d61893SAlex Zinenko 102430d61893SAlex Zinenko /// Removes the operation with the given name from the symbol table and erases 102530d61893SAlex Zinenko /// it, throws if there is no such symbol in the table. 102630d61893SAlex Zinenko void dunderDel(const std::string &name); 102730d61893SAlex Zinenko 102830d61893SAlex Zinenko /// Inserts the given operation into the symbol table. The operation must have 102930d61893SAlex Zinenko /// the symbol trait. 103030d61893SAlex Zinenko PyAttribute insert(PyOperationBase &symbol); 103130d61893SAlex Zinenko 1032bdc31837SStella Laurenzo /// Gets and sets the name of a symbol op. 1033bdc31837SStella Laurenzo static PyAttribute getSymbolName(PyOperationBase &symbol); 1034bdc31837SStella Laurenzo static void setSymbolName(PyOperationBase &symbol, const std::string &name); 1035bdc31837SStella Laurenzo 1036bdc31837SStella Laurenzo /// Gets and sets the visibility of a symbol op. 1037bdc31837SStella Laurenzo static PyAttribute getVisibility(PyOperationBase &symbol); 1038bdc31837SStella Laurenzo static void setVisibility(PyOperationBase &symbol, 1039bdc31837SStella Laurenzo const std::string &visibility); 1040bdc31837SStella Laurenzo 1041bdc31837SStella Laurenzo /// Replaces all symbol uses within an operation. See the API 1042bdc31837SStella Laurenzo /// mlirSymbolTableReplaceAllSymbolUses for all caveats. 1043bdc31837SStella Laurenzo static void replaceAllSymbolUses(const std::string &oldSymbol, 1044bdc31837SStella Laurenzo const std::string &newSymbol, 1045bdc31837SStella Laurenzo PyOperationBase &from); 1046bdc31837SStella Laurenzo 1047bdc31837SStella Laurenzo /// Walks all symbol tables under and including 'from'. 1048bdc31837SStella Laurenzo static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, 1049bdc31837SStella Laurenzo pybind11::object callback); 1050bdc31837SStella Laurenzo 105130d61893SAlex Zinenko /// Casts the bindings class into the C API structure. MlirSymbolTable()105230d61893SAlex Zinenko operator MlirSymbolTable() { return symbolTable; } 105330d61893SAlex Zinenko 105430d61893SAlex Zinenko private: 105530d61893SAlex Zinenko PyOperationRef operation; 105630d61893SAlex Zinenko MlirSymbolTable symbolTable; 105730d61893SAlex Zinenko }; 105830d61893SAlex Zinenko 1059436c6c9cSStella Laurenzo void populateIRAffine(pybind11::module &m); 1060436c6c9cSStella Laurenzo void populateIRAttributes(pybind11::module &m); 1061436c6c9cSStella Laurenzo void populateIRCore(pybind11::module &m); 106214c92070SAlex Zinenko void populateIRInterfaces(pybind11::module &m); 1063436c6c9cSStella Laurenzo void populateIRTypes(pybind11::module &m); 1064436c6c9cSStella Laurenzo 1065436c6c9cSStella Laurenzo } // namespace python 1066436c6c9cSStella Laurenzo } // namespace mlir 1067436c6c9cSStella Laurenzo 1068436c6c9cSStella Laurenzo namespace pybind11 { 1069436c6c9cSStella Laurenzo namespace detail { 1070436c6c9cSStella Laurenzo 1071436c6c9cSStella Laurenzo template <> 1072436c6c9cSStella Laurenzo struct type_caster<mlir::python::DefaultingPyMlirContext> 1073436c6c9cSStella Laurenzo : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {}; 1074436c6c9cSStella Laurenzo template <> 1075436c6c9cSStella Laurenzo struct type_caster<mlir::python::DefaultingPyLocation> 1076436c6c9cSStella Laurenzo : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {}; 1077436c6c9cSStella Laurenzo 1078436c6c9cSStella Laurenzo } // namespace detail 1079436c6c9cSStella Laurenzo } // namespace pybind11 1080436c6c9cSStella Laurenzo 1081436c6c9cSStella Laurenzo #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H 1082