1 //===- IRModules.h - IR Submodules of pybind module -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H 10 #define MLIR_BINDINGS_PYTHON_IRMODULES_H 11 12 #include <utility> 13 #include <vector> 14 15 #include "PybindUtils.h" 16 17 #include "mlir-c/AffineExpr.h" 18 #include "mlir-c/AffineMap.h" 19 #include "mlir-c/Diagnostics.h" 20 #include "mlir-c/IR.h" 21 #include "mlir-c/IntegerSet.h" 22 #include "llvm/ADT/DenseMap.h" 23 #include "llvm/ADT/Optional.h" 24 25 namespace mlir { 26 namespace python { 27 28 class PyBlock; 29 class PyDiagnostic; 30 class PyDiagnosticHandler; 31 class PyInsertionPoint; 32 class PyLocation; 33 class DefaultingPyLocation; 34 class PyMlirContext; 35 class DefaultingPyMlirContext; 36 class PyModule; 37 class PyOperation; 38 class PyType; 39 class PySymbolTable; 40 class PyValue; 41 42 /// Template for a reference to a concrete type which captures a python 43 /// reference to its underlying python object. 44 template <typename T> 45 class PyObjectRef { 46 public: PyObjectRef(T * referrent,pybind11::object object)47 PyObjectRef(T *referrent, pybind11::object object) 48 : referrent(referrent), object(std::move(object)) { 49 assert(this->referrent && 50 "cannot construct PyObjectRef with null referrent"); 51 assert(this->object && "cannot construct PyObjectRef with null object"); 52 } PyObjectRef(PyObjectRef && other)53 PyObjectRef(PyObjectRef &&other) 54 : referrent(other.referrent), object(std::move(other.object)) { 55 other.referrent = nullptr; 56 assert(!other.object); 57 } PyObjectRef(const PyObjectRef & other)58 PyObjectRef(const PyObjectRef &other) 59 : referrent(other.referrent), object(other.object /* copies */) {} 60 ~PyObjectRef() = default; 61 getRefCount()62 int getRefCount() { 63 if (!object) 64 return 0; 65 return object.ref_count(); 66 } 67 68 /// Releases the object held by this instance, returning it. 69 /// This is the proper thing to return from a function that wants to return 70 /// the reference. Note that this does not work from initializers. releaseObject()71 pybind11::object releaseObject() { 72 assert(referrent && object); 73 referrent = nullptr; 74 auto stolen = std::move(object); 75 return stolen; 76 } 77 get()78 T *get() { return referrent; } 79 T *operator->() { 80 assert(referrent && object); 81 return referrent; 82 } getObject()83 pybind11::object getObject() { 84 assert(referrent && object); 85 return object; 86 } 87 operator bool() const { return referrent && object; } 88 89 private: 90 T *referrent; 91 pybind11::object object; 92 }; 93 94 /// Tracks an entry in the thread context stack. New entries are pushed onto 95 /// here for each with block that activates a new InsertionPoint, Context or 96 /// Location. 97 /// 98 /// Pushing either a Location or InsertionPoint also pushes its associated 99 /// Context. Pushing a Context will not modify the Location or InsertionPoint 100 /// unless if they are from a different context, in which case, they are 101 /// cleared. 102 class PyThreadContextEntry { 103 public: 104 enum class FrameKind { 105 Context, 106 InsertionPoint, 107 Location, 108 }; 109 PyThreadContextEntry(FrameKind frameKind,pybind11::object context,pybind11::object insertionPoint,pybind11::object location)110 PyThreadContextEntry(FrameKind frameKind, pybind11::object context, 111 pybind11::object insertionPoint, 112 pybind11::object location) 113 : context(std::move(context)), insertionPoint(std::move(insertionPoint)), 114 location(std::move(location)), frameKind(frameKind) {} 115 116 /// Gets the top of stack context and return nullptr if not defined. 117 static PyMlirContext *getDefaultContext(); 118 119 /// Gets the top of stack insertion point and return nullptr if not defined. 120 static PyInsertionPoint *getDefaultInsertionPoint(); 121 122 /// Gets the top of stack location and returns nullptr if not defined. 123 static PyLocation *getDefaultLocation(); 124 125 PyMlirContext *getContext(); 126 PyInsertionPoint *getInsertionPoint(); 127 PyLocation *getLocation(); getFrameKind()128 FrameKind getFrameKind() { return frameKind; } 129 130 /// Stack management. 131 static PyThreadContextEntry *getTopOfStack(); 132 static pybind11::object pushContext(PyMlirContext &context); 133 static void popContext(PyMlirContext &context); 134 static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); 135 static void popInsertionPoint(PyInsertionPoint &insertionPoint); 136 static pybind11::object pushLocation(PyLocation &location); 137 static void popLocation(PyLocation &location); 138 139 /// Gets the thread local stack. 140 static std::vector<PyThreadContextEntry> &getStack(); 141 142 private: 143 static void push(FrameKind frameKind, pybind11::object context, 144 pybind11::object insertionPoint, pybind11::object location); 145 146 /// An object reference to the PyContext. 147 pybind11::object context; 148 /// An object reference to the current insertion point. 149 pybind11::object insertionPoint; 150 /// An object reference to the current location. 151 pybind11::object location; 152 // The kind of push that was performed. 153 FrameKind frameKind; 154 }; 155 156 /// Wrapper around MlirContext. 157 using PyMlirContextRef = PyObjectRef<PyMlirContext>; 158 class PyMlirContext { 159 public: 160 PyMlirContext() = delete; 161 PyMlirContext(const PyMlirContext &) = delete; 162 PyMlirContext(PyMlirContext &&) = delete; 163 164 /// For the case of a python __init__ (py::init) method, pybind11 is quite 165 /// strict about needing to return a pointer that is not yet associated to 166 /// an py::object. Since the forContext() method acts like a pool, possibly 167 /// returning a recycled context, it does not satisfy this need. The usual 168 /// way in python to accomplish such a thing is to override __new__, but 169 /// that is also not supported by pybind11. Instead, we use this entry 170 /// point which always constructs a fresh context (which cannot alias an 171 /// existing one because it is fresh). 172 static PyMlirContext *createNewContextForInit(); 173 174 /// Returns a context reference for the singleton PyMlirContext wrapper for 175 /// the given context. 176 static PyMlirContextRef forContext(MlirContext context); 177 ~PyMlirContext(); 178 179 /// Accesses the underlying MlirContext. get()180 MlirContext get() { return context; } 181 182 /// Gets a strong reference to this context, which will ensure it is kept 183 /// alive for the life of the reference. getRef()184 PyMlirContextRef getRef() { 185 return PyMlirContextRef(this, pybind11::cast(this)); 186 } 187 188 /// Gets a capsule wrapping the void* within the MlirContext. 189 pybind11::object getCapsule(); 190 191 /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. 192 /// Note that PyMlirContext instances are uniqued, so the returned object 193 /// may be a pre-existing object. Ownership of the underlying MlirContext 194 /// is taken by calling this function. 195 static pybind11::object createFromCapsule(pybind11::object capsule); 196 197 /// Gets the count of live context objects. Used for testing. 198 static size_t getLiveCount(); 199 200 /// Gets the count of live operations associated with this context. 201 /// Used for testing. 202 size_t getLiveOperationCount(); 203 204 /// Clears the live operations map, returning the number of entries which were 205 /// invalidated. To be used as a safety mechanism so that API end-users can't 206 /// corrupt by holding references they shouldn't have accessed in the first 207 /// place. 208 size_t clearLiveOperations(); 209 210 /// Gets the count of live modules associated with this context. 211 /// Used for testing. 212 size_t getLiveModuleCount(); 213 214 /// Enter and exit the context manager. 215 pybind11::object contextEnter(); 216 void contextExit(const pybind11::object &excType, 217 const pybind11::object &excVal, 218 const pybind11::object &excTb); 219 220 /// Attaches a Python callback as a diagnostic handler, returning a 221 /// registration object (internally a PyDiagnosticHandler). 222 pybind11::object attachDiagnosticHandler(pybind11::object callback); 223 224 private: 225 PyMlirContext(MlirContext context); 226 // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, 227 // preserving the relationship that an MlirContext maps to a single 228 // PyMlirContext wrapper. This could be replaced in the future with an 229 // extension mechanism on the MlirContext for stashing user pointers. 230 // Note that this holds a handle, which does not imply ownership. 231 // Mappings will be removed when the context is destructed. 232 using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>; 233 static LiveContextMap &getLiveContexts(); 234 235 // Interns all live modules associated with this context. Modules tracked 236 // in this map are valid. When a module is invalidated, it is removed 237 // from this map, and while it still exists as an instance, any 238 // attempt to access it will raise an error. 239 using LiveModuleMap = 240 llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>; 241 LiveModuleMap liveModules; 242 243 // Interns all live operations associated with this context. Operations 244 // tracked in this map are valid. When an operation is invalidated, it is 245 // removed from this map, and while it still exists as an instance, any 246 // attempt to access it will raise an error. 247 using LiveOperationMap = 248 llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>; 249 LiveOperationMap liveOperations; 250 251 MlirContext context; 252 friend class PyModule; 253 friend class PyOperation; 254 }; 255 256 /// Used in function arguments when None should resolve to the current context 257 /// manager set instance. 258 class DefaultingPyMlirContext 259 : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { 260 public: 261 using Defaulting::Defaulting; 262 static constexpr const char kTypeDescription[] = "mlir.ir.Context"; 263 static PyMlirContext &resolve(); 264 }; 265 266 /// Base class for all objects that directly or indirectly depend on an 267 /// MlirContext. The lifetime of the context will extend at least to the 268 /// lifetime of these instances. 269 /// Immutable objects that depend on a context extend this directly. 270 class BaseContextObject { 271 public: BaseContextObject(PyMlirContextRef ref)272 BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { 273 assert(this->contextRef && 274 "context object constructed with null context ref"); 275 } 276 277 /// Accesses the context reference. getContext()278 PyMlirContextRef &getContext() { return contextRef; } 279 280 private: 281 PyMlirContextRef contextRef; 282 }; 283 284 /// Python class mirroring the C MlirDiagnostic struct. Note that these structs 285 /// are only valid for the duration of a diagnostic callback and attempting 286 /// to access them outside of that will raise an exception. This applies to 287 /// nested diagnostics (in the notes) as well. 288 class PyDiagnostic { 289 public: PyDiagnostic(MlirDiagnostic diagnostic)290 PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {} 291 void invalidate(); isValid()292 bool isValid() { return valid; } 293 MlirDiagnosticSeverity getSeverity(); 294 PyLocation getLocation(); 295 pybind11::str getMessage(); 296 pybind11::tuple getNotes(); 297 298 private: 299 MlirDiagnostic diagnostic; 300 301 void checkValid(); 302 /// If notes have been materialized from the diagnostic, then this will 303 /// be populated with the corresponding objects (all castable to 304 /// PyDiagnostic). 305 llvm::Optional<pybind11::tuple> materializedNotes; 306 bool valid = true; 307 }; 308 309 /// Represents a diagnostic handler attached to the context. The handler's 310 /// callback will be invoked with PyDiagnostic instances until the detach() 311 /// method is called or the context is destroyed. A diagnostic handler can be 312 /// the subject of a `with` block, which will detach it when the block exits. 313 /// 314 /// Since diagnostic handlers can call back into Python code which can do 315 /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions, 316 /// etc), this is generally not deemed to be a great user-level API. Users 317 /// should generally use some form of DiagnosticCollector. If the handler raises 318 /// any exceptions, they will just be emitted to stderr and dropped. 319 /// 320 /// The unique usage of this class means that its lifetime management is 321 /// different from most other parts of the API. Instances are always created 322 /// in an attached state and can transition to a detached state by either: 323 /// a) The context being destroyed and unregistering all handlers. 324 /// b) An explicit call to detach(). 325 /// The object may remain live from a Python perspective for an arbitrary time 326 /// after detachment, but there is nothing the user can do with it (since there 327 /// is no way to attach an existing handler object). 328 class PyDiagnosticHandler { 329 public: 330 PyDiagnosticHandler(MlirContext context, pybind11::object callback); 331 ~PyDiagnosticHandler(); 332 isAttached()333 bool isAttached() { return registeredID.hasValue(); } getHadError()334 bool getHadError() { return hadError; } 335 336 /// Detaches the handler. Does nothing if not attached. 337 void detach(); 338 contextEnter()339 pybind11::object contextEnter() { return pybind11::cast(this); } contextExit(const pybind11::object & excType,const pybind11::object & excVal,const pybind11::object & excTb)340 void contextExit(const pybind11::object &excType, 341 const pybind11::object &excVal, 342 const pybind11::object &excTb) { 343 detach(); 344 } 345 346 private: 347 MlirContext context; 348 pybind11::object callback; 349 llvm::Optional<MlirDiagnosticHandlerID> registeredID; 350 bool hadError = false; 351 friend class PyMlirContext; 352 }; 353 354 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in 355 /// order to differentiate it from the `Dialect` base class which is extended by 356 /// plugins which extend dialect functionality through extension python code. 357 /// This should be seen as the "low-level" object and `Dialect` as the 358 /// high-level, user facing object. 359 class PyDialectDescriptor : public BaseContextObject { 360 public: PyDialectDescriptor(PyMlirContextRef contextRef,MlirDialect dialect)361 PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) 362 : BaseContextObject(std::move(contextRef)), dialect(dialect) {} 363 get()364 MlirDialect get() { return dialect; } 365 366 private: 367 MlirDialect dialect; 368 }; 369 370 /// User-level object for accessing dialects with dotted syntax such as: 371 /// ctx.dialect.std 372 class PyDialects : public BaseContextObject { 373 public: PyDialects(PyMlirContextRef contextRef)374 PyDialects(PyMlirContextRef contextRef) 375 : BaseContextObject(std::move(contextRef)) {} 376 377 MlirDialect getDialectForKey(const std::string &key, bool attrError); 378 }; 379 380 /// User-level dialect object. For dialects that have a registered extension, 381 /// this will be the base class of the extension dialect type. For un-extended, 382 /// objects of this type will be returned directly. 383 class PyDialect { 384 public: PyDialect(pybind11::object descriptor)385 PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} 386 getDescriptor()387 pybind11::object getDescriptor() { return descriptor; } 388 389 private: 390 pybind11::object descriptor; 391 }; 392 393 /// Wrapper around an MlirDialectRegistry. 394 /// Upon construction, the Python wrapper takes ownership of the 395 /// underlying MlirDialectRegistry. 396 class PyDialectRegistry { 397 public: PyDialectRegistry()398 PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {} PyDialectRegistry(MlirDialectRegistry registry)399 PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {} ~PyDialectRegistry()400 ~PyDialectRegistry() { 401 if (!mlirDialectRegistryIsNull(registry)) 402 mlirDialectRegistryDestroy(registry); 403 } 404 PyDialectRegistry(PyDialectRegistry &) = delete; PyDialectRegistry(PyDialectRegistry && other)405 PyDialectRegistry(PyDialectRegistry &&other) : registry(other.registry) { 406 other.registry = {nullptr}; 407 } 408 MlirDialectRegistry()409 operator MlirDialectRegistry() const { return registry; } get()410 MlirDialectRegistry get() const { return registry; } 411 412 pybind11::object getCapsule(); 413 static PyDialectRegistry createFromCapsule(pybind11::object capsule); 414 415 private: 416 MlirDialectRegistry registry; 417 }; 418 419 /// Wrapper around an MlirLocation. 420 class PyLocation : public BaseContextObject { 421 public: PyLocation(PyMlirContextRef contextRef,MlirLocation loc)422 PyLocation(PyMlirContextRef contextRef, MlirLocation loc) 423 : BaseContextObject(std::move(contextRef)), loc(loc) {} 424 MlirLocation()425 operator MlirLocation() const { return loc; } get()426 MlirLocation get() const { return loc; } 427 428 /// Enter and exit the context manager. 429 pybind11::object contextEnter(); 430 void contextExit(const pybind11::object &excType, 431 const pybind11::object &excVal, 432 const pybind11::object &excTb); 433 434 /// Gets a capsule wrapping the void* within the MlirLocation. 435 pybind11::object getCapsule(); 436 437 /// Creates a PyLocation from the MlirLocation wrapped by a capsule. 438 /// Note that PyLocation instances are uniqued, so the returned object 439 /// may be a pre-existing object. Ownership of the underlying MlirLocation 440 /// is taken by calling this function. 441 static PyLocation createFromCapsule(pybind11::object capsule); 442 443 private: 444 MlirLocation loc; 445 }; 446 447 /// Used in function arguments when None should resolve to the current context 448 /// manager set instance. 449 class DefaultingPyLocation 450 : public Defaulting<DefaultingPyLocation, PyLocation> { 451 public: 452 using Defaulting::Defaulting; 453 static constexpr const char kTypeDescription[] = "mlir.ir.Location"; 454 static PyLocation &resolve(); 455 MlirLocation()456 operator MlirLocation() const { return *get(); } 457 }; 458 459 /// Wrapper around MlirModule. 460 /// This is the top-level, user-owned object that contains regions/ops/blocks. 461 class PyModule; 462 using PyModuleRef = PyObjectRef<PyModule>; 463 class PyModule : public BaseContextObject { 464 public: 465 /// Returns a PyModule reference for the given MlirModule. This may return 466 /// a pre-existing or new object. 467 static PyModuleRef forModule(MlirModule module); 468 PyModule(PyModule &) = delete; 469 PyModule(PyMlirContext &&) = delete; 470 ~PyModule(); 471 472 /// Gets the backing MlirModule. get()473 MlirModule get() { return module; } 474 475 /// Gets a strong reference to this module. getRef()476 PyModuleRef getRef() { 477 return PyModuleRef(this, 478 pybind11::reinterpret_borrow<pybind11::object>(handle)); 479 } 480 481 /// Gets a capsule wrapping the void* within the MlirModule. 482 /// Note that the module does not (yet) provide a corresponding factory for 483 /// constructing from a capsule as that would require uniquing PyModule 484 /// instances, which is not currently done. 485 pybind11::object getCapsule(); 486 487 /// Creates a PyModule from the MlirModule wrapped by a capsule. 488 /// Note that PyModule instances are uniqued, so the returned object 489 /// may be a pre-existing object. Ownership of the underlying MlirModule 490 /// is taken by calling this function. 491 static pybind11::object createFromCapsule(pybind11::object capsule); 492 493 private: 494 PyModule(PyMlirContextRef contextRef, MlirModule module); 495 MlirModule module; 496 pybind11::handle handle; 497 }; 498 499 /// Base class for PyOperation and PyOpView which exposes the primary, user 500 /// visible methods for manipulating it. 501 class PyOperationBase { 502 public: 503 virtual ~PyOperationBase() = default; 504 /// Implements the bound 'print' method and helps with others. 505 void print(pybind11::object fileObject, bool binary, 506 llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo, 507 bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, 508 bool assumeVerified); 509 pybind11::object getAsm(bool binary, 510 llvm::Optional<int64_t> largeElementsLimit, 511 bool enableDebugInfo, bool prettyDebugInfo, 512 bool printGenericOpForm, bool useLocalScope, 513 bool assumeVerified); 514 515 /// Moves the operation before or after the other operation. 516 void moveAfter(PyOperationBase &other); 517 void moveBefore(PyOperationBase &other); 518 519 /// Each must provide access to the raw Operation. 520 virtual PyOperation &getOperation() = 0; 521 }; 522 523 /// Wrapper around PyOperation. 524 /// Operations exist in either an attached (dependent) or detached (top-level) 525 /// state. In the detached state (as on creation), an operation is owned by 526 /// the creator and its lifetime extends either until its reference count 527 /// drops to zero or it is attached to a parent, at which point its lifetime 528 /// is bounded by its top-level parent reference. 529 class PyOperation; 530 using PyOperationRef = PyObjectRef<PyOperation>; 531 class PyOperation : public PyOperationBase, public BaseContextObject { 532 public: 533 ~PyOperation() override; getOperation()534 PyOperation &getOperation() override { return *this; } 535 536 /// Returns a PyOperation for the given MlirOperation, optionally associating 537 /// it with a parentKeepAlive. 538 static PyOperationRef 539 forOperation(PyMlirContextRef contextRef, MlirOperation operation, 540 pybind11::object parentKeepAlive = pybind11::object()); 541 542 /// Creates a detached operation. The operation must not be associated with 543 /// any existing live operation. 544 static PyOperationRef 545 createDetached(PyMlirContextRef contextRef, MlirOperation operation, 546 pybind11::object parentKeepAlive = pybind11::object()); 547 548 /// Detaches the operation from its parent block and updates its state 549 /// accordingly. detachFromParent()550 void detachFromParent() { 551 mlirOperationRemoveFromParent(getOperation()); 552 setDetached(); 553 parentKeepAlive = pybind11::object(); 554 } 555 556 /// Gets the backing operation. MlirOperation()557 operator MlirOperation() const { return get(); } get()558 MlirOperation get() const { 559 checkValid(); 560 return operation; 561 } 562 getRef()563 PyOperationRef getRef() { 564 return PyOperationRef( 565 this, pybind11::reinterpret_borrow<pybind11::object>(handle)); 566 } 567 isAttached()568 bool isAttached() { return attached; } 569 void setAttached(const pybind11::object &parent = pybind11::object()) { 570 assert(!attached && "operation already attached"); 571 attached = true; 572 } setDetached()573 void setDetached() { 574 assert(attached && "operation already detached"); 575 attached = false; 576 } 577 void checkValid() const; 578 579 /// Gets the owning block or raises an exception if the operation has no 580 /// owning block. 581 PyBlock getBlock(); 582 583 /// Gets the parent operation or raises an exception if the operation has 584 /// no parent. 585 llvm::Optional<PyOperationRef> getParentOperation(); 586 587 /// Gets a capsule wrapping the void* within the MlirOperation. 588 pybind11::object getCapsule(); 589 590 /// Creates a PyOperation from the MlirOperation wrapped by a capsule. 591 /// Ownership of the underlying MlirOperation is taken by calling this 592 /// function. 593 static pybind11::object createFromCapsule(pybind11::object capsule); 594 595 /// Creates an operation. See corresponding python docstring. 596 static pybind11::object 597 create(const std::string &name, llvm::Optional<std::vector<PyType *>> results, 598 llvm::Optional<std::vector<PyValue *>> operands, 599 llvm::Optional<pybind11::dict> attributes, 600 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 601 DefaultingPyLocation location, const pybind11::object &ip); 602 603 /// Creates an OpView suitable for this operation. 604 pybind11::object createOpView(); 605 606 /// Erases the underlying MlirOperation, removes its pointer from the 607 /// parent context's live operations map, and sets the valid bit false. 608 void erase(); 609 610 /// Invalidate the operation. setInvalid()611 void setInvalid() { valid = false; } 612 613 /// Clones this operation. 614 pybind11::object clone(const pybind11::object &ip); 615 616 private: 617 PyOperation(PyMlirContextRef contextRef, MlirOperation operation); 618 static PyOperationRef createInstance(PyMlirContextRef contextRef, 619 MlirOperation operation, 620 pybind11::object parentKeepAlive); 621 622 MlirOperation operation; 623 pybind11::handle handle; 624 // Keeps the parent alive, regardless of whether it is an Operation or 625 // Module. 626 // TODO: As implemented, this facility is only sufficient for modeling the 627 // trivial module parent back-reference. Generalize this to also account for 628 // transitions from detached to attached and address TODOs in the 629 // ir_operation.py regarding testing corresponding lifetime guarantees. 630 pybind11::object parentKeepAlive; 631 bool attached = true; 632 bool valid = true; 633 634 friend class PyOperationBase; 635 friend class PySymbolTable; 636 }; 637 638 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for 639 /// providing more instance-specific accessors and serve as the base class for 640 /// custom ODS-style operation classes. Since this class is subclass on the 641 /// python side, it must present an __init__ method that operates in pure 642 /// python types. 643 class PyOpView : public PyOperationBase { 644 public: 645 PyOpView(const pybind11::object &operationObject); getOperation()646 PyOperation &getOperation() override { return operation; } 647 648 static pybind11::object createRawSubclass(const pybind11::object &userClass); 649 getOperationObject()650 pybind11::object getOperationObject() { return operationObject; } 651 652 static pybind11::object 653 buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList, 654 pybind11::list operandList, 655 llvm::Optional<pybind11::dict> attributes, 656 llvm::Optional<std::vector<PyBlock *>> successors, 657 llvm::Optional<int> regions, DefaultingPyLocation location, 658 const pybind11::object &maybeIp); 659 660 private: 661 PyOperation &operation; // For efficient, cast-free access from C++ 662 pybind11::object operationObject; // Holds the reference. 663 }; 664 665 /// Wrapper around an MlirRegion. 666 /// Regions are managed completely by their containing operation. Unlike the 667 /// C++ API, the python API does not support detached regions. 668 class PyRegion { 669 public: PyRegion(PyOperationRef parentOperation,MlirRegion region)670 PyRegion(PyOperationRef parentOperation, MlirRegion region) 671 : parentOperation(std::move(parentOperation)), region(region) { 672 assert(!mlirRegionIsNull(region) && "python region cannot be null"); 673 } MlirRegion()674 operator MlirRegion() const { return region; } 675 get()676 MlirRegion get() { return region; } getParentOperation()677 PyOperationRef &getParentOperation() { return parentOperation; } 678 checkValid()679 void checkValid() { return parentOperation->checkValid(); } 680 681 private: 682 PyOperationRef parentOperation; 683 MlirRegion region; 684 }; 685 686 /// Wrapper around an MlirBlock. 687 /// Blocks are managed completely by their containing operation. Unlike the 688 /// C++ API, the python API does not support detached blocks. 689 class PyBlock { 690 public: PyBlock(PyOperationRef parentOperation,MlirBlock block)691 PyBlock(PyOperationRef parentOperation, MlirBlock block) 692 : parentOperation(std::move(parentOperation)), block(block) { 693 assert(!mlirBlockIsNull(block) && "python block cannot be null"); 694 } 695 get()696 MlirBlock get() { return block; } getParentOperation()697 PyOperationRef &getParentOperation() { return parentOperation; } 698 checkValid()699 void checkValid() { return parentOperation->checkValid(); } 700 701 private: 702 PyOperationRef parentOperation; 703 MlirBlock block; 704 }; 705 706 /// An insertion point maintains a pointer to a Block and a reference operation. 707 /// Calls to insert() will insert a new operation before the 708 /// reference operation. If the reference operation is null, then appends to 709 /// the end of the block. 710 class PyInsertionPoint { 711 public: 712 /// Creates an insertion point positioned after the last operation in the 713 /// block, but still inside the block. 714 PyInsertionPoint(PyBlock &block); 715 /// Creates an insertion point positioned before a reference operation. 716 PyInsertionPoint(PyOperationBase &beforeOperationBase); 717 718 /// Shortcut to create an insertion point at the beginning of the block. 719 static PyInsertionPoint atBlockBegin(PyBlock &block); 720 /// Shortcut to create an insertion point before the block terminator. 721 static PyInsertionPoint atBlockTerminator(PyBlock &block); 722 723 /// Inserts an operation. 724 void insert(PyOperationBase &operationBase); 725 726 /// Enter and exit the context manager. 727 pybind11::object contextEnter(); 728 void contextExit(const pybind11::object &excType, 729 const pybind11::object &excVal, 730 const pybind11::object &excTb); 731 getBlock()732 PyBlock &getBlock() { return block; } 733 734 private: 735 // Trampoline constructor that avoids null initializing members while 736 // looking up parents. PyInsertionPoint(PyBlock block,llvm::Optional<PyOperationRef> refOperation)737 PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation) 738 : refOperation(std::move(refOperation)), block(std::move(block)) {} 739 740 llvm::Optional<PyOperationRef> refOperation; 741 PyBlock block; 742 }; 743 /// Wrapper around the generic MlirType. 744 /// The lifetime of a type is bound by the PyContext that created it. 745 class PyType : public BaseContextObject { 746 public: PyType(PyMlirContextRef contextRef,MlirType type)747 PyType(PyMlirContextRef contextRef, MlirType type) 748 : BaseContextObject(std::move(contextRef)), type(type) {} 749 bool operator==(const PyType &other); MlirType()750 operator MlirType() const { return type; } get()751 MlirType get() const { return type; } 752 753 /// Gets a capsule wrapping the void* within the MlirType. 754 pybind11::object getCapsule(); 755 756 /// Creates a PyType from the MlirType wrapped by a capsule. 757 /// Note that PyType instances are uniqued, so the returned object 758 /// may be a pre-existing object. Ownership of the underlying MlirType 759 /// is taken by calling this function. 760 static PyType createFromCapsule(pybind11::object capsule); 761 762 private: 763 MlirType type; 764 }; 765 766 /// CRTP base classes for Python types that subclass Type and should be 767 /// castable from it (i.e. via something like IntegerType(t)). 768 /// By default, type class hierarchies are one level deep (i.e. a 769 /// concrete type class extends PyType); however, intermediate python-visible 770 /// base classes can be modeled by specifying a BaseTy. 771 template <typename DerivedTy, typename BaseTy = PyType> 772 class PyConcreteType : public BaseTy { 773 public: 774 // Derived classes must define statics for: 775 // IsAFunctionTy isaFunction 776 // const char *pyClassName 777 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 778 using IsAFunctionTy = bool (*)(MlirType); 779 780 PyConcreteType() = default; PyConcreteType(PyMlirContextRef contextRef,MlirType t)781 PyConcreteType(PyMlirContextRef contextRef, MlirType t) 782 : BaseTy(std::move(contextRef), t) {} PyConcreteType(PyType & orig)783 PyConcreteType(PyType &orig) 784 : PyConcreteType(orig.getContext(), castFrom(orig)) {} 785 castFrom(PyType & orig)786 static MlirType castFrom(PyType &orig) { 787 if (!DerivedTy::isaFunction(orig)) { 788 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 789 throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + 790 DerivedTy::pyClassName + 791 " (from " + origRepr + ")"); 792 } 793 return orig; 794 } 795 bind(pybind11::module & m)796 static void bind(pybind11::module &m) { 797 auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); 798 cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(), 799 pybind11::arg("cast_from_type")); 800 cls.def_static( 801 "isinstance", 802 [](PyType &otherType) -> bool { 803 return DerivedTy::isaFunction(otherType); 804 }, 805 pybind11::arg("other")); 806 DerivedTy::bindDerived(cls); 807 } 808 809 /// Implemented by derived classes to add methods to the Python subclass. bindDerived(ClassTy & m)810 static void bindDerived(ClassTy &m) {} 811 }; 812 813 /// Wrapper around the generic MlirAttribute. 814 /// The lifetime of a type is bound by the PyContext that created it. 815 class PyAttribute : public BaseContextObject { 816 public: PyAttribute(PyMlirContextRef contextRef,MlirAttribute attr)817 PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 818 : BaseContextObject(std::move(contextRef)), attr(attr) {} 819 bool operator==(const PyAttribute &other); MlirAttribute()820 operator MlirAttribute() const { return attr; } get()821 MlirAttribute get() const { return attr; } 822 823 /// Gets a capsule wrapping the void* within the MlirAttribute. 824 pybind11::object getCapsule(); 825 826 /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. 827 /// Note that PyAttribute instances are uniqued, so the returned object 828 /// may be a pre-existing object. Ownership of the underlying MlirAttribute 829 /// is taken by calling this function. 830 static PyAttribute createFromCapsule(pybind11::object capsule); 831 832 private: 833 MlirAttribute attr; 834 }; 835 836 /// Represents a Python MlirNamedAttr, carrying an optional owned name. 837 /// TODO: Refactor this and the C-API to be based on an Identifier owned 838 /// by the context so as to avoid ownership issues here. 839 class PyNamedAttribute { 840 public: 841 /// Constructs a PyNamedAttr that retains an owned name. This should be 842 /// used in any code that originates an MlirNamedAttribute from a python 843 /// string. 844 /// The lifetime of the PyNamedAttr must extend to the lifetime of the 845 /// passed attribute. 846 PyNamedAttribute(MlirAttribute attr, std::string ownedName); 847 848 MlirNamedAttribute namedAttr; 849 850 private: 851 // Since the MlirNamedAttr contains an internal pointer to the actual 852 // memory of the owned string, it must be heap allocated to remain valid. 853 // Otherwise, strings that fit within the small object optimization threshold 854 // will have their memory address change as the containing object is moved, 855 // resulting in an invalid aliased pointer. 856 std::unique_ptr<std::string> ownedName; 857 }; 858 859 /// CRTP base classes for Python attributes that subclass Attribute and should 860 /// be castable from it (i.e. via something like StringAttr(attr)). 861 /// By default, attribute class hierarchies are one level deep (i.e. a 862 /// concrete attribute class extends PyAttribute); however, intermediate 863 /// python-visible base classes can be modeled by specifying a BaseTy. 864 template <typename DerivedTy, typename BaseTy = PyAttribute> 865 class PyConcreteAttribute : public BaseTy { 866 public: 867 // Derived classes must define statics for: 868 // IsAFunctionTy isaFunction 869 // const char *pyClassName 870 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 871 using IsAFunctionTy = bool (*)(MlirAttribute); 872 873 PyConcreteAttribute() = default; PyConcreteAttribute(PyMlirContextRef contextRef,MlirAttribute attr)874 PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 875 : BaseTy(std::move(contextRef), attr) {} PyConcreteAttribute(PyAttribute & orig)876 PyConcreteAttribute(PyAttribute &orig) 877 : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} 878 castFrom(PyAttribute & orig)879 static MlirAttribute castFrom(PyAttribute &orig) { 880 if (!DerivedTy::isaFunction(orig)) { 881 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 882 throw SetPyError(PyExc_ValueError, 883 llvm::Twine("Cannot cast attribute to ") + 884 DerivedTy::pyClassName + " (from " + origRepr + ")"); 885 } 886 return orig; 887 } 888 bind(pybind11::module & m)889 static void bind(pybind11::module &m) { 890 auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), 891 pybind11::module_local()); 892 cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(), 893 pybind11::arg("cast_from_attr")); 894 cls.def_static( 895 "isinstance", 896 [](PyAttribute &otherAttr) -> bool { 897 return DerivedTy::isaFunction(otherAttr); 898 }, 899 pybind11::arg("other")); 900 cls.def_property_readonly("type", [](PyAttribute &attr) { 901 return PyType(attr.getContext(), mlirAttributeGetType(attr)); 902 }); 903 DerivedTy::bindDerived(cls); 904 } 905 906 /// Implemented by derived classes to add methods to the Python subclass. bindDerived(ClassTy & m)907 static void bindDerived(ClassTy &m) {} 908 }; 909 910 /// Wrapper around the generic MlirValue. 911 /// Values are managed completely by the operation that resulted in their 912 /// definition. For op result value, this is the operation that defines the 913 /// value. For block argument values, this is the operation that contains the 914 /// block to which the value is an argument (blocks cannot be detached in Python 915 /// bindings so such operation always exists). 916 class PyValue { 917 public: PyValue(PyOperationRef parentOperation,MlirValue value)918 PyValue(PyOperationRef parentOperation, MlirValue value) 919 : parentOperation(std::move(parentOperation)), value(value) {} MlirValue()920 operator MlirValue() const { return value; } 921 get()922 MlirValue get() { return value; } getParentOperation()923 PyOperationRef &getParentOperation() { return parentOperation; } 924 checkValid()925 void checkValid() { return parentOperation->checkValid(); } 926 927 /// Gets a capsule wrapping the void* within the MlirValue. 928 pybind11::object getCapsule(); 929 930 /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of 931 /// the underlying MlirValue is still tied to the owning operation. 932 static PyValue createFromCapsule(pybind11::object capsule); 933 934 private: 935 PyOperationRef parentOperation; 936 MlirValue value; 937 }; 938 939 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. 940 class PyAffineExpr : public BaseContextObject { 941 public: PyAffineExpr(PyMlirContextRef contextRef,MlirAffineExpr affineExpr)942 PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 943 : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} 944 bool operator==(const PyAffineExpr &other); MlirAffineExpr()945 operator MlirAffineExpr() const { return affineExpr; } get()946 MlirAffineExpr get() const { return affineExpr; } 947 948 /// Gets a capsule wrapping the void* within the MlirAffineExpr. 949 pybind11::object getCapsule(); 950 951 /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. 952 /// Note that PyAffineExpr instances are uniqued, so the returned object 953 /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr 954 /// is taken by calling this function. 955 static PyAffineExpr createFromCapsule(pybind11::object capsule); 956 957 PyAffineExpr add(const PyAffineExpr &other) const; 958 PyAffineExpr mul(const PyAffineExpr &other) const; 959 PyAffineExpr floorDiv(const PyAffineExpr &other) const; 960 PyAffineExpr ceilDiv(const PyAffineExpr &other) const; 961 PyAffineExpr mod(const PyAffineExpr &other) const; 962 963 private: 964 MlirAffineExpr affineExpr; 965 }; 966 967 class PyAffineMap : public BaseContextObject { 968 public: PyAffineMap(PyMlirContextRef contextRef,MlirAffineMap affineMap)969 PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) 970 : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} 971 bool operator==(const PyAffineMap &other); MlirAffineMap()972 operator MlirAffineMap() const { return affineMap; } get()973 MlirAffineMap get() const { return affineMap; } 974 975 /// Gets a capsule wrapping the void* within the MlirAffineMap. 976 pybind11::object getCapsule(); 977 978 /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. 979 /// Note that PyAffineMap instances are uniqued, so the returned object 980 /// may be a pre-existing object. Ownership of the underlying MlirAffineMap 981 /// is taken by calling this function. 982 static PyAffineMap createFromCapsule(pybind11::object capsule); 983 984 private: 985 MlirAffineMap affineMap; 986 }; 987 988 class PyIntegerSet : public BaseContextObject { 989 public: PyIntegerSet(PyMlirContextRef contextRef,MlirIntegerSet integerSet)990 PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) 991 : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} 992 bool operator==(const PyIntegerSet &other); MlirIntegerSet()993 operator MlirIntegerSet() const { return integerSet; } get()994 MlirIntegerSet get() const { return integerSet; } 995 996 /// Gets a capsule wrapping the void* within the MlirIntegerSet. 997 pybind11::object getCapsule(); 998 999 /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. 1000 /// Note that PyIntegerSet instances may be uniqued, so the returned object 1001 /// may be a pre-existing object. Integer sets are owned by the context. 1002 static PyIntegerSet createFromCapsule(pybind11::object capsule); 1003 1004 private: 1005 MlirIntegerSet integerSet; 1006 }; 1007 1008 /// Bindings for MLIR symbol tables. 1009 class PySymbolTable { 1010 public: 1011 /// Constructs a symbol table for the given operation. 1012 explicit PySymbolTable(PyOperationBase &operation); 1013 1014 /// Destroys the symbol table. ~PySymbolTable()1015 ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); } 1016 1017 /// Returns the symbol (opview) with the given name, throws if there is no 1018 /// such symbol in the table. 1019 pybind11::object dunderGetItem(const std::string &name); 1020 1021 /// Removes the given operation from the symbol table and erases it. 1022 void erase(PyOperationBase &symbol); 1023 1024 /// Removes the operation with the given name from the symbol table and erases 1025 /// it, throws if there is no such symbol in the table. 1026 void dunderDel(const std::string &name); 1027 1028 /// Inserts the given operation into the symbol table. The operation must have 1029 /// the symbol trait. 1030 PyAttribute insert(PyOperationBase &symbol); 1031 1032 /// Gets and sets the name of a symbol op. 1033 static PyAttribute getSymbolName(PyOperationBase &symbol); 1034 static void setSymbolName(PyOperationBase &symbol, const std::string &name); 1035 1036 /// Gets and sets the visibility of a symbol op. 1037 static PyAttribute getVisibility(PyOperationBase &symbol); 1038 static void setVisibility(PyOperationBase &symbol, 1039 const std::string &visibility); 1040 1041 /// Replaces all symbol uses within an operation. See the API 1042 /// mlirSymbolTableReplaceAllSymbolUses for all caveats. 1043 static void replaceAllSymbolUses(const std::string &oldSymbol, 1044 const std::string &newSymbol, 1045 PyOperationBase &from); 1046 1047 /// Walks all symbol tables under and including 'from'. 1048 static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, 1049 pybind11::object callback); 1050 1051 /// Casts the bindings class into the C API structure. MlirSymbolTable()1052 operator MlirSymbolTable() { return symbolTable; } 1053 1054 private: 1055 PyOperationRef operation; 1056 MlirSymbolTable symbolTable; 1057 }; 1058 1059 void populateIRAffine(pybind11::module &m); 1060 void populateIRAttributes(pybind11::module &m); 1061 void populateIRCore(pybind11::module &m); 1062 void populateIRInterfaces(pybind11::module &m); 1063 void populateIRTypes(pybind11::module &m); 1064 1065 } // namespace python 1066 } // namespace mlir 1067 1068 namespace pybind11 { 1069 namespace detail { 1070 1071 template <> 1072 struct type_caster<mlir::python::DefaultingPyMlirContext> 1073 : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {}; 1074 template <> 1075 struct type_caster<mlir::python::DefaultingPyLocation> 1076 : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {}; 1077 1078 } // namespace detail 1079 } // namespace pybind11 1080 1081 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H 1082