1 //===- IRModules.h - IR Submodules of pybind module -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H 10 #define MLIR_BINDINGS_PYTHON_IRMODULES_H 11 12 #include <vector> 13 14 #include "PybindUtils.h" 15 16 #include "mlir-c/AffineExpr.h" 17 #include "mlir-c/AffineMap.h" 18 #include "mlir-c/Diagnostics.h" 19 #include "mlir-c/IR.h" 20 #include "mlir-c/IntegerSet.h" 21 #include "llvm/ADT/DenseMap.h" 22 #include "llvm/ADT/Optional.h" 23 24 namespace mlir { 25 namespace python { 26 27 class PyBlock; 28 class PyDiagnostic; 29 class PyDiagnosticHandler; 30 class PyInsertionPoint; 31 class PyLocation; 32 class DefaultingPyLocation; 33 class PyMlirContext; 34 class DefaultingPyMlirContext; 35 class PyModule; 36 class PyOperation; 37 class PyType; 38 class PySymbolTable; 39 class PyValue; 40 41 /// Template for a reference to a concrete type which captures a python 42 /// reference to its underlying python object. 43 template <typename T> 44 class PyObjectRef { 45 public: 46 PyObjectRef(T *referrent, pybind11::object object) 47 : referrent(referrent), object(std::move(object)) { 48 assert(this->referrent && 49 "cannot construct PyObjectRef with null referrent"); 50 assert(this->object && "cannot construct PyObjectRef with null object"); 51 } 52 PyObjectRef(PyObjectRef &&other) 53 : referrent(other.referrent), object(std::move(other.object)) { 54 other.referrent = nullptr; 55 assert(!other.object); 56 } 57 PyObjectRef(const PyObjectRef &other) 58 : referrent(other.referrent), object(other.object /* copies */) {} 59 ~PyObjectRef() {} 60 61 int getRefCount() { 62 if (!object) 63 return 0; 64 return object.ref_count(); 65 } 66 67 /// Releases the object held by this instance, returning it. 68 /// This is the proper thing to return from a function that wants to return 69 /// the reference. Note that this does not work from initializers. 70 pybind11::object releaseObject() { 71 assert(referrent && object); 72 referrent = nullptr; 73 auto stolen = std::move(object); 74 return stolen; 75 } 76 77 T *get() { return referrent; } 78 T *operator->() { 79 assert(referrent && object); 80 return referrent; 81 } 82 pybind11::object getObject() { 83 assert(referrent && object); 84 return object; 85 } 86 operator bool() const { return referrent && object; } 87 88 private: 89 T *referrent; 90 pybind11::object object; 91 }; 92 93 /// Tracks an entry in the thread context stack. New entries are pushed onto 94 /// here for each with block that activates a new InsertionPoint, Context or 95 /// Location. 96 /// 97 /// Pushing either a Location or InsertionPoint also pushes its associated 98 /// Context. Pushing a Context will not modify the Location or InsertionPoint 99 /// unless if they are from a different context, in which case, they are 100 /// cleared. 101 class PyThreadContextEntry { 102 public: 103 enum class FrameKind { 104 Context, 105 InsertionPoint, 106 Location, 107 }; 108 109 PyThreadContextEntry(FrameKind frameKind, pybind11::object context, 110 pybind11::object insertionPoint, 111 pybind11::object location) 112 : context(std::move(context)), insertionPoint(std::move(insertionPoint)), 113 location(std::move(location)), frameKind(frameKind) {} 114 115 /// Gets the top of stack context and return nullptr if not defined. 116 static PyMlirContext *getDefaultContext(); 117 118 /// Gets the top of stack insertion point and return nullptr if not defined. 119 static PyInsertionPoint *getDefaultInsertionPoint(); 120 121 /// Gets the top of stack location and returns nullptr if not defined. 122 static PyLocation *getDefaultLocation(); 123 124 PyMlirContext *getContext(); 125 PyInsertionPoint *getInsertionPoint(); 126 PyLocation *getLocation(); 127 FrameKind getFrameKind() { return frameKind; } 128 129 /// Stack management. 130 static PyThreadContextEntry *getTopOfStack(); 131 static pybind11::object pushContext(PyMlirContext &context); 132 static void popContext(PyMlirContext &context); 133 static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); 134 static void popInsertionPoint(PyInsertionPoint &insertionPoint); 135 static pybind11::object pushLocation(PyLocation &location); 136 static void popLocation(PyLocation &location); 137 138 /// Gets the thread local stack. 139 static std::vector<PyThreadContextEntry> &getStack(); 140 141 private: 142 static void push(FrameKind frameKind, pybind11::object context, 143 pybind11::object insertionPoint, pybind11::object location); 144 145 /// An object reference to the PyContext. 146 pybind11::object context; 147 /// An object reference to the current insertion point. 148 pybind11::object insertionPoint; 149 /// An object reference to the current location. 150 pybind11::object location; 151 // The kind of push that was performed. 152 FrameKind frameKind; 153 }; 154 155 /// Wrapper around MlirContext. 156 using PyMlirContextRef = PyObjectRef<PyMlirContext>; 157 class PyMlirContext { 158 public: 159 PyMlirContext() = delete; 160 PyMlirContext(const PyMlirContext &) = delete; 161 PyMlirContext(PyMlirContext &&) = delete; 162 163 /// For the case of a python __init__ (py::init) method, pybind11 is quite 164 /// strict about needing to return a pointer that is not yet associated to 165 /// an py::object. Since the forContext() method acts like a pool, possibly 166 /// returning a recycled context, it does not satisfy this need. The usual 167 /// way in python to accomplish such a thing is to override __new__, but 168 /// that is also not supported by pybind11. Instead, we use this entry 169 /// point which always constructs a fresh context (which cannot alias an 170 /// existing one because it is fresh). 171 static PyMlirContext *createNewContextForInit(); 172 173 /// Returns a context reference for the singleton PyMlirContext wrapper for 174 /// the given context. 175 static PyMlirContextRef forContext(MlirContext context); 176 ~PyMlirContext(); 177 178 /// Accesses the underlying MlirContext. 179 MlirContext get() { return context; } 180 181 /// Gets a strong reference to this context, which will ensure it is kept 182 /// alive for the life of the reference. 183 PyMlirContextRef getRef() { 184 return PyMlirContextRef(this, pybind11::cast(this)); 185 } 186 187 /// Gets a capsule wrapping the void* within the MlirContext. 188 pybind11::object getCapsule(); 189 190 /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. 191 /// Note that PyMlirContext instances are uniqued, so the returned object 192 /// may be a pre-existing object. Ownership of the underlying MlirContext 193 /// is taken by calling this function. 194 static pybind11::object createFromCapsule(pybind11::object capsule); 195 196 /// Gets the count of live context objects. Used for testing. 197 static size_t getLiveCount(); 198 199 /// Gets the count of live operations associated with this context. 200 /// Used for testing. 201 size_t getLiveOperationCount(); 202 203 /// Gets the count of live modules associated with this context. 204 /// Used for testing. 205 size_t getLiveModuleCount(); 206 207 /// Enter and exit the context manager. 208 pybind11::object contextEnter(); 209 void contextExit(const pybind11::object &excType, 210 const pybind11::object &excVal, 211 const pybind11::object &excTb); 212 213 /// Attaches a Python callback as a diagnostic handler, returning a 214 /// registration object (internally a PyDiagnosticHandler). 215 pybind11::object attachDiagnosticHandler(pybind11::object callback); 216 217 private: 218 PyMlirContext(MlirContext context); 219 // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, 220 // preserving the relationship that an MlirContext maps to a single 221 // PyMlirContext wrapper. This could be replaced in the future with an 222 // extension mechanism on the MlirContext for stashing user pointers. 223 // Note that this holds a handle, which does not imply ownership. 224 // Mappings will be removed when the context is destructed. 225 using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>; 226 static LiveContextMap &getLiveContexts(); 227 228 // Interns all live modules associated with this context. Modules tracked 229 // in this map are valid. When a module is invalidated, it is removed 230 // from this map, and while it still exists as an instance, any 231 // attempt to access it will raise an error. 232 using LiveModuleMap = 233 llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>; 234 LiveModuleMap liveModules; 235 236 // Interns all live operations associated with this context. Operations 237 // tracked in this map are valid. When an operation is invalidated, it is 238 // removed from this map, and while it still exists as an instance, any 239 // attempt to access it will raise an error. 240 using LiveOperationMap = 241 llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>; 242 LiveOperationMap liveOperations; 243 244 MlirContext context; 245 friend class PyModule; 246 friend class PyOperation; 247 }; 248 249 /// Used in function arguments when None should resolve to the current context 250 /// manager set instance. 251 class DefaultingPyMlirContext 252 : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { 253 public: 254 using Defaulting::Defaulting; 255 static constexpr const char kTypeDescription[] = "mlir.ir.Context"; 256 static PyMlirContext &resolve(); 257 }; 258 259 /// Base class for all objects that directly or indirectly depend on an 260 /// MlirContext. The lifetime of the context will extend at least to the 261 /// lifetime of these instances. 262 /// Immutable objects that depend on a context extend this directly. 263 class BaseContextObject { 264 public: 265 BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { 266 assert(this->contextRef && 267 "context object constructed with null context ref"); 268 } 269 270 /// Accesses the context reference. 271 PyMlirContextRef &getContext() { return contextRef; } 272 273 private: 274 PyMlirContextRef contextRef; 275 }; 276 277 /// Python class mirroring the C MlirDiagnostic struct. Note that these structs 278 /// are only valid for the duration of a diagnostic callback and attempting 279 /// to access them outside of that will raise an exception. This applies to 280 /// nested diagnostics (in the notes) as well. 281 class PyDiagnostic { 282 public: 283 PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {} 284 void invalidate(); 285 bool isValid() { return valid; } 286 MlirDiagnosticSeverity getSeverity(); 287 PyLocation getLocation(); 288 pybind11::str getMessage(); 289 pybind11::tuple getNotes(); 290 291 private: 292 MlirDiagnostic diagnostic; 293 294 void checkValid(); 295 /// If notes have been materialized from the diagnostic, then this will 296 /// be populated with the corresponding objects (all castable to 297 /// PyDiagnostic). 298 llvm::Optional<pybind11::tuple> materializedNotes; 299 bool valid = true; 300 }; 301 302 /// Represents a diagnostic handler attached to the context. The handler's 303 /// callback will be invoked with PyDiagnostic instances until the detach() 304 /// method is called or the context is destroyed. A diagnostic handler can be 305 /// the subject of a `with` block, which will detach it when the block exits. 306 /// 307 /// Since diagnostic handlers can call back into Python code which can do 308 /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions, 309 /// etc), this is generally not deemed to be a great user-level API. Users 310 /// should generally use some form of DiagnosticCollector. If the handler raises 311 /// any exceptions, they will just be emitted to stderr and dropped. 312 /// 313 /// The unique usage of this class means that its lifetime management is 314 /// different from most other parts of the API. Instances are always created 315 /// in an attached state and can transition to a detached state by either: 316 /// a) The context being destroyed and unregistering all handlers. 317 /// b) An explicit call to detach(). 318 /// The object may remain live from a Python perspective for an arbitrary time 319 /// after detachment, but there is nothing the user can do with it (since there 320 /// is no way to attach an existing handler object). 321 class PyDiagnosticHandler { 322 public: 323 PyDiagnosticHandler(MlirContext context, pybind11::object callback); 324 ~PyDiagnosticHandler(); 325 326 bool isAttached() { return registeredID.hasValue(); } 327 bool getHadError() { return hadError; } 328 329 /// Detaches the handler. Does nothing if not attached. 330 void detach(); 331 332 pybind11::object contextEnter() { return pybind11::cast(this); } 333 void contextExit(pybind11::object excType, pybind11::object excVal, 334 pybind11::object excTb) { 335 detach(); 336 } 337 338 private: 339 MlirContext context; 340 pybind11::object callback; 341 llvm::Optional<MlirDiagnosticHandlerID> registeredID; 342 bool hadError = false; 343 friend class PyMlirContext; 344 }; 345 346 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in 347 /// order to differentiate it from the `Dialect` base class which is extended by 348 /// plugins which extend dialect functionality through extension python code. 349 /// This should be seen as the "low-level" object and `Dialect` as the 350 /// high-level, user facing object. 351 class PyDialectDescriptor : public BaseContextObject { 352 public: 353 PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) 354 : BaseContextObject(std::move(contextRef)), dialect(dialect) {} 355 356 MlirDialect get() { return dialect; } 357 358 private: 359 MlirDialect dialect; 360 }; 361 362 /// User-level object for accessing dialects with dotted syntax such as: 363 /// ctx.dialect.std 364 class PyDialects : public BaseContextObject { 365 public: 366 PyDialects(PyMlirContextRef contextRef) 367 : BaseContextObject(std::move(contextRef)) {} 368 369 MlirDialect getDialectForKey(const std::string &key, bool attrError); 370 }; 371 372 /// User-level dialect object. For dialects that have a registered extension, 373 /// this will be the base class of the extension dialect type. For un-extended, 374 /// objects of this type will be returned directly. 375 class PyDialect { 376 public: 377 PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} 378 379 pybind11::object getDescriptor() { return descriptor; } 380 381 private: 382 pybind11::object descriptor; 383 }; 384 385 /// Wrapper around an MlirLocation. 386 class PyLocation : public BaseContextObject { 387 public: 388 PyLocation(PyMlirContextRef contextRef, MlirLocation loc) 389 : BaseContextObject(std::move(contextRef)), loc(loc) {} 390 391 operator MlirLocation() const { return loc; } 392 MlirLocation get() const { return loc; } 393 394 /// Enter and exit the context manager. 395 pybind11::object contextEnter(); 396 void contextExit(const pybind11::object &excType, 397 const pybind11::object &excVal, 398 const pybind11::object &excTb); 399 400 /// Gets a capsule wrapping the void* within the MlirLocation. 401 pybind11::object getCapsule(); 402 403 /// Creates a PyLocation from the MlirLocation wrapped by a capsule. 404 /// Note that PyLocation instances are uniqued, so the returned object 405 /// may be a pre-existing object. Ownership of the underlying MlirLocation 406 /// is taken by calling this function. 407 static PyLocation createFromCapsule(pybind11::object capsule); 408 409 private: 410 MlirLocation loc; 411 }; 412 413 /// Used in function arguments when None should resolve to the current context 414 /// manager set instance. 415 class DefaultingPyLocation 416 : public Defaulting<DefaultingPyLocation, PyLocation> { 417 public: 418 using Defaulting::Defaulting; 419 static constexpr const char kTypeDescription[] = "mlir.ir.Location"; 420 static PyLocation &resolve(); 421 422 operator MlirLocation() const { return *get(); } 423 }; 424 425 /// Wrapper around MlirModule. 426 /// This is the top-level, user-owned object that contains regions/ops/blocks. 427 class PyModule; 428 using PyModuleRef = PyObjectRef<PyModule>; 429 class PyModule : public BaseContextObject { 430 public: 431 /// Returns a PyModule reference for the given MlirModule. This may return 432 /// a pre-existing or new object. 433 static PyModuleRef forModule(MlirModule module); 434 PyModule(PyModule &) = delete; 435 PyModule(PyMlirContext &&) = delete; 436 ~PyModule(); 437 438 /// Gets the backing MlirModule. 439 MlirModule get() { return module; } 440 441 /// Gets a strong reference to this module. 442 PyModuleRef getRef() { 443 return PyModuleRef(this, 444 pybind11::reinterpret_borrow<pybind11::object>(handle)); 445 } 446 447 /// Gets a capsule wrapping the void* within the MlirModule. 448 /// Note that the module does not (yet) provide a corresponding factory for 449 /// constructing from a capsule as that would require uniquing PyModule 450 /// instances, which is not currently done. 451 pybind11::object getCapsule(); 452 453 /// Creates a PyModule from the MlirModule wrapped by a capsule. 454 /// Note that PyModule instances are uniqued, so the returned object 455 /// may be a pre-existing object. Ownership of the underlying MlirModule 456 /// is taken by calling this function. 457 static pybind11::object createFromCapsule(pybind11::object capsule); 458 459 private: 460 PyModule(PyMlirContextRef contextRef, MlirModule module); 461 MlirModule module; 462 pybind11::handle handle; 463 }; 464 465 /// Base class for PyOperation and PyOpView which exposes the primary, user 466 /// visible methods for manipulating it. 467 class PyOperationBase { 468 public: 469 virtual ~PyOperationBase() = default; 470 /// Implements the bound 'print' method and helps with others. 471 void print(pybind11::object fileObject, bool binary, 472 llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo, 473 bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, 474 bool assumeVerified); 475 pybind11::object getAsm(bool binary, 476 llvm::Optional<int64_t> largeElementsLimit, 477 bool enableDebugInfo, bool prettyDebugInfo, 478 bool printGenericOpForm, bool useLocalScope, 479 bool assumeVerified); 480 481 /// Moves the operation before or after the other operation. 482 void moveAfter(PyOperationBase &other); 483 void moveBefore(PyOperationBase &other); 484 485 /// Each must provide access to the raw Operation. 486 virtual PyOperation &getOperation() = 0; 487 }; 488 489 /// Wrapper around PyOperation. 490 /// Operations exist in either an attached (dependent) or detached (top-level) 491 /// state. In the detached state (as on creation), an operation is owned by 492 /// the creator and its lifetime extends either until its reference count 493 /// drops to zero or it is attached to a parent, at which point its lifetime 494 /// is bounded by its top-level parent reference. 495 class PyOperation; 496 using PyOperationRef = PyObjectRef<PyOperation>; 497 class PyOperation : public PyOperationBase, public BaseContextObject { 498 public: 499 ~PyOperation(); 500 PyOperation &getOperation() override { return *this; } 501 502 /// Returns a PyOperation for the given MlirOperation, optionally associating 503 /// it with a parentKeepAlive. 504 static PyOperationRef 505 forOperation(PyMlirContextRef contextRef, MlirOperation operation, 506 pybind11::object parentKeepAlive = pybind11::object()); 507 508 /// Creates a detached operation. The operation must not be associated with 509 /// any existing live operation. 510 static PyOperationRef 511 createDetached(PyMlirContextRef contextRef, MlirOperation operation, 512 pybind11::object parentKeepAlive = pybind11::object()); 513 514 /// Detaches the operation from its parent block and updates its state 515 /// accordingly. 516 void detachFromParent() { 517 mlirOperationRemoveFromParent(getOperation()); 518 setDetached(); 519 parentKeepAlive = pybind11::object(); 520 } 521 522 /// Gets the backing operation. 523 operator MlirOperation() const { return get(); } 524 MlirOperation get() const { 525 checkValid(); 526 return operation; 527 } 528 529 PyOperationRef getRef() { 530 return PyOperationRef( 531 this, pybind11::reinterpret_borrow<pybind11::object>(handle)); 532 } 533 534 bool isAttached() { return attached; } 535 void setAttached(pybind11::object parent = pybind11::object()) { 536 assert(!attached && "operation already attached"); 537 attached = true; 538 } 539 void setDetached() { 540 assert(attached && "operation already detached"); 541 attached = false; 542 } 543 void checkValid() const; 544 545 /// Gets the owning block or raises an exception if the operation has no 546 /// owning block. 547 PyBlock getBlock(); 548 549 /// Gets the parent operation or raises an exception if the operation has 550 /// no parent. 551 llvm::Optional<PyOperationRef> getParentOperation(); 552 553 /// Gets a capsule wrapping the void* within the MlirOperation. 554 pybind11::object getCapsule(); 555 556 /// Creates a PyOperation from the MlirOperation wrapped by a capsule. 557 /// Ownership of the underlying MlirOperation is taken by calling this 558 /// function. 559 static pybind11::object createFromCapsule(pybind11::object capsule); 560 561 /// Creates an operation. See corresponding python docstring. 562 static pybind11::object 563 create(const std::string &name, llvm::Optional<std::vector<PyType *>> results, 564 llvm::Optional<std::vector<PyValue *>> operands, 565 llvm::Optional<pybind11::dict> attributes, 566 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 567 DefaultingPyLocation location, const pybind11::object &ip); 568 569 /// Creates an OpView suitable for this operation. 570 pybind11::object createOpView(); 571 572 /// Erases the underlying MlirOperation, removes its pointer from the 573 /// parent context's live operations map, and sets the valid bit false. 574 void erase(); 575 576 private: 577 PyOperation(PyMlirContextRef contextRef, MlirOperation operation); 578 static PyOperationRef createInstance(PyMlirContextRef contextRef, 579 MlirOperation operation, 580 pybind11::object parentKeepAlive); 581 582 MlirOperation operation; 583 pybind11::handle handle; 584 // Keeps the parent alive, regardless of whether it is an Operation or 585 // Module. 586 // TODO: As implemented, this facility is only sufficient for modeling the 587 // trivial module parent back-reference. Generalize this to also account for 588 // transitions from detached to attached and address TODOs in the 589 // ir_operation.py regarding testing corresponding lifetime guarantees. 590 pybind11::object parentKeepAlive; 591 bool attached = true; 592 bool valid = true; 593 594 friend class PyOperationBase; 595 friend class PySymbolTable; 596 }; 597 598 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for 599 /// providing more instance-specific accessors and serve as the base class for 600 /// custom ODS-style operation classes. Since this class is subclass on the 601 /// python side, it must present an __init__ method that operates in pure 602 /// python types. 603 class PyOpView : public PyOperationBase { 604 public: 605 PyOpView(const pybind11::object &operationObject); 606 PyOperation &getOperation() override { return operation; } 607 608 static pybind11::object createRawSubclass(const pybind11::object &userClass); 609 610 pybind11::object getOperationObject() { return operationObject; } 611 612 static pybind11::object 613 buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList, 614 pybind11::list operandList, 615 llvm::Optional<pybind11::dict> attributes, 616 llvm::Optional<std::vector<PyBlock *>> successors, 617 llvm::Optional<int> regions, DefaultingPyLocation location, 618 const pybind11::object &maybeIp); 619 620 private: 621 PyOperation &operation; // For efficient, cast-free access from C++ 622 pybind11::object operationObject; // Holds the reference. 623 }; 624 625 /// Wrapper around an MlirRegion. 626 /// Regions are managed completely by their containing operation. Unlike the 627 /// C++ API, the python API does not support detached regions. 628 class PyRegion { 629 public: 630 PyRegion(PyOperationRef parentOperation, MlirRegion region) 631 : parentOperation(std::move(parentOperation)), region(region) { 632 assert(!mlirRegionIsNull(region) && "python region cannot be null"); 633 } 634 operator MlirRegion() const { return region; } 635 636 MlirRegion get() { return region; } 637 PyOperationRef &getParentOperation() { return parentOperation; } 638 639 void checkValid() { return parentOperation->checkValid(); } 640 641 private: 642 PyOperationRef parentOperation; 643 MlirRegion region; 644 }; 645 646 /// Wrapper around an MlirBlock. 647 /// Blocks are managed completely by their containing operation. Unlike the 648 /// C++ API, the python API does not support detached blocks. 649 class PyBlock { 650 public: 651 PyBlock(PyOperationRef parentOperation, MlirBlock block) 652 : parentOperation(std::move(parentOperation)), block(block) { 653 assert(!mlirBlockIsNull(block) && "python block cannot be null"); 654 } 655 656 MlirBlock get() { return block; } 657 PyOperationRef &getParentOperation() { return parentOperation; } 658 659 void checkValid() { return parentOperation->checkValid(); } 660 661 private: 662 PyOperationRef parentOperation; 663 MlirBlock block; 664 }; 665 666 /// An insertion point maintains a pointer to a Block and a reference operation. 667 /// Calls to insert() will insert a new operation before the 668 /// reference operation. If the reference operation is null, then appends to 669 /// the end of the block. 670 class PyInsertionPoint { 671 public: 672 /// Creates an insertion point positioned after the last operation in the 673 /// block, but still inside the block. 674 PyInsertionPoint(PyBlock &block); 675 /// Creates an insertion point positioned before a reference operation. 676 PyInsertionPoint(PyOperationBase &beforeOperationBase); 677 678 /// Shortcut to create an insertion point at the beginning of the block. 679 static PyInsertionPoint atBlockBegin(PyBlock &block); 680 /// Shortcut to create an insertion point before the block terminator. 681 static PyInsertionPoint atBlockTerminator(PyBlock &block); 682 683 /// Inserts an operation. 684 void insert(PyOperationBase &operationBase); 685 686 /// Enter and exit the context manager. 687 pybind11::object contextEnter(); 688 void contextExit(const pybind11::object &excType, 689 const pybind11::object &excVal, 690 const pybind11::object &excTb); 691 692 PyBlock &getBlock() { return block; } 693 694 private: 695 // Trampoline constructor that avoids null initializing members while 696 // looking up parents. 697 PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation) 698 : refOperation(std::move(refOperation)), block(std::move(block)) {} 699 700 llvm::Optional<PyOperationRef> refOperation; 701 PyBlock block; 702 }; 703 /// Wrapper around the generic MlirType. 704 /// The lifetime of a type is bound by the PyContext that created it. 705 class PyType : public BaseContextObject { 706 public: 707 PyType(PyMlirContextRef contextRef, MlirType type) 708 : BaseContextObject(std::move(contextRef)), type(type) {} 709 bool operator==(const PyType &other); 710 operator MlirType() const { return type; } 711 MlirType get() const { return type; } 712 713 /// Gets a capsule wrapping the void* within the MlirType. 714 pybind11::object getCapsule(); 715 716 /// Creates a PyType from the MlirType wrapped by a capsule. 717 /// Note that PyType instances are uniqued, so the returned object 718 /// may be a pre-existing object. Ownership of the underlying MlirType 719 /// is taken by calling this function. 720 static PyType createFromCapsule(pybind11::object capsule); 721 722 private: 723 MlirType type; 724 }; 725 726 /// CRTP base classes for Python types that subclass Type and should be 727 /// castable from it (i.e. via something like IntegerType(t)). 728 /// By default, type class hierarchies are one level deep (i.e. a 729 /// concrete type class extends PyType); however, intermediate python-visible 730 /// base classes can be modeled by specifying a BaseTy. 731 template <typename DerivedTy, typename BaseTy = PyType> 732 class PyConcreteType : public BaseTy { 733 public: 734 // Derived classes must define statics for: 735 // IsAFunctionTy isaFunction 736 // const char *pyClassName 737 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 738 using IsAFunctionTy = bool (*)(MlirType); 739 740 PyConcreteType() = default; 741 PyConcreteType(PyMlirContextRef contextRef, MlirType t) 742 : BaseTy(std::move(contextRef), t) {} 743 PyConcreteType(PyType &orig) 744 : PyConcreteType(orig.getContext(), castFrom(orig)) {} 745 746 static MlirType castFrom(PyType &orig) { 747 if (!DerivedTy::isaFunction(orig)) { 748 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 749 throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + 750 DerivedTy::pyClassName + 751 " (from " + origRepr + ")"); 752 } 753 return orig; 754 } 755 756 static void bind(pybind11::module &m) { 757 auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); 758 cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(), 759 pybind11::arg("cast_from_type")); 760 cls.def_static( 761 "isinstance", 762 [](PyType &otherType) -> bool { 763 return DerivedTy::isaFunction(otherType); 764 }, 765 pybind11::arg("other")); 766 DerivedTy::bindDerived(cls); 767 } 768 769 /// Implemented by derived classes to add methods to the Python subclass. 770 static void bindDerived(ClassTy &m) {} 771 }; 772 773 /// Wrapper around the generic MlirAttribute. 774 /// The lifetime of a type is bound by the PyContext that created it. 775 class PyAttribute : public BaseContextObject { 776 public: 777 PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 778 : BaseContextObject(std::move(contextRef)), attr(attr) {} 779 bool operator==(const PyAttribute &other); 780 operator MlirAttribute() const { return attr; } 781 MlirAttribute get() const { return attr; } 782 783 /// Gets a capsule wrapping the void* within the MlirAttribute. 784 pybind11::object getCapsule(); 785 786 /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. 787 /// Note that PyAttribute instances are uniqued, so the returned object 788 /// may be a pre-existing object. Ownership of the underlying MlirAttribute 789 /// is taken by calling this function. 790 static PyAttribute createFromCapsule(pybind11::object capsule); 791 792 private: 793 MlirAttribute attr; 794 }; 795 796 /// Represents a Python MlirNamedAttr, carrying an optional owned name. 797 /// TODO: Refactor this and the C-API to be based on an Identifier owned 798 /// by the context so as to avoid ownership issues here. 799 class PyNamedAttribute { 800 public: 801 /// Constructs a PyNamedAttr that retains an owned name. This should be 802 /// used in any code that originates an MlirNamedAttribute from a python 803 /// string. 804 /// The lifetime of the PyNamedAttr must extend to the lifetime of the 805 /// passed attribute. 806 PyNamedAttribute(MlirAttribute attr, std::string ownedName); 807 808 MlirNamedAttribute namedAttr; 809 810 private: 811 // Since the MlirNamedAttr contains an internal pointer to the actual 812 // memory of the owned string, it must be heap allocated to remain valid. 813 // Otherwise, strings that fit within the small object optimization threshold 814 // will have their memory address change as the containing object is moved, 815 // resulting in an invalid aliased pointer. 816 std::unique_ptr<std::string> ownedName; 817 }; 818 819 /// CRTP base classes for Python attributes that subclass Attribute and should 820 /// be castable from it (i.e. via something like StringAttr(attr)). 821 /// By default, attribute class hierarchies are one level deep (i.e. a 822 /// concrete attribute class extends PyAttribute); however, intermediate 823 /// python-visible base classes can be modeled by specifying a BaseTy. 824 template <typename DerivedTy, typename BaseTy = PyAttribute> 825 class PyConcreteAttribute : public BaseTy { 826 public: 827 // Derived classes must define statics for: 828 // IsAFunctionTy isaFunction 829 // const char *pyClassName 830 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 831 using IsAFunctionTy = bool (*)(MlirAttribute); 832 833 PyConcreteAttribute() = default; 834 PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 835 : BaseTy(std::move(contextRef), attr) {} 836 PyConcreteAttribute(PyAttribute &orig) 837 : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} 838 839 static MlirAttribute castFrom(PyAttribute &orig) { 840 if (!DerivedTy::isaFunction(orig)) { 841 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 842 throw SetPyError(PyExc_ValueError, 843 llvm::Twine("Cannot cast attribute to ") + 844 DerivedTy::pyClassName + " (from " + origRepr + ")"); 845 } 846 return orig; 847 } 848 849 static void bind(pybind11::module &m) { 850 auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), 851 pybind11::module_local()); 852 cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(), 853 pybind11::arg("cast_from_attr")); 854 cls.def_static( 855 "isinstance", 856 [](PyAttribute &otherAttr) -> bool { 857 return DerivedTy::isaFunction(otherAttr); 858 }, 859 pybind11::arg("other")); 860 cls.def_property_readonly("type", [](PyAttribute &attr) { 861 return PyType(attr.getContext(), mlirAttributeGetType(attr)); 862 }); 863 DerivedTy::bindDerived(cls); 864 } 865 866 /// Implemented by derived classes to add methods to the Python subclass. 867 static void bindDerived(ClassTy &m) {} 868 }; 869 870 /// Wrapper around the generic MlirValue. 871 /// Values are managed completely by the operation that resulted in their 872 /// definition. For op result value, this is the operation that defines the 873 /// value. For block argument values, this is the operation that contains the 874 /// block to which the value is an argument (blocks cannot be detached in Python 875 /// bindings so such operation always exists). 876 class PyValue { 877 public: 878 PyValue(PyOperationRef parentOperation, MlirValue value) 879 : parentOperation(parentOperation), value(value) {} 880 operator MlirValue() const { return value; } 881 882 MlirValue get() { return value; } 883 PyOperationRef &getParentOperation() { return parentOperation; } 884 885 void checkValid() { return parentOperation->checkValid(); } 886 887 /// Gets a capsule wrapping the void* within the MlirValue. 888 pybind11::object getCapsule(); 889 890 /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of 891 /// the underlying MlirValue is still tied to the owning operation. 892 static PyValue createFromCapsule(pybind11::object capsule); 893 894 private: 895 PyOperationRef parentOperation; 896 MlirValue value; 897 }; 898 899 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. 900 class PyAffineExpr : public BaseContextObject { 901 public: 902 PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 903 : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} 904 bool operator==(const PyAffineExpr &other); 905 operator MlirAffineExpr() const { return affineExpr; } 906 MlirAffineExpr get() const { return affineExpr; } 907 908 /// Gets a capsule wrapping the void* within the MlirAffineExpr. 909 pybind11::object getCapsule(); 910 911 /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. 912 /// Note that PyAffineExpr instances are uniqued, so the returned object 913 /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr 914 /// is taken by calling this function. 915 static PyAffineExpr createFromCapsule(pybind11::object capsule); 916 917 PyAffineExpr add(const PyAffineExpr &other) const; 918 PyAffineExpr mul(const PyAffineExpr &other) const; 919 PyAffineExpr floorDiv(const PyAffineExpr &other) const; 920 PyAffineExpr ceilDiv(const PyAffineExpr &other) const; 921 PyAffineExpr mod(const PyAffineExpr &other) const; 922 923 private: 924 MlirAffineExpr affineExpr; 925 }; 926 927 class PyAffineMap : public BaseContextObject { 928 public: 929 PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) 930 : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} 931 bool operator==(const PyAffineMap &other); 932 operator MlirAffineMap() const { return affineMap; } 933 MlirAffineMap get() const { return affineMap; } 934 935 /// Gets a capsule wrapping the void* within the MlirAffineMap. 936 pybind11::object getCapsule(); 937 938 /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. 939 /// Note that PyAffineMap instances are uniqued, so the returned object 940 /// may be a pre-existing object. Ownership of the underlying MlirAffineMap 941 /// is taken by calling this function. 942 static PyAffineMap createFromCapsule(pybind11::object capsule); 943 944 private: 945 MlirAffineMap affineMap; 946 }; 947 948 class PyIntegerSet : public BaseContextObject { 949 public: 950 PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) 951 : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} 952 bool operator==(const PyIntegerSet &other); 953 operator MlirIntegerSet() const { return integerSet; } 954 MlirIntegerSet get() const { return integerSet; } 955 956 /// Gets a capsule wrapping the void* within the MlirIntegerSet. 957 pybind11::object getCapsule(); 958 959 /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. 960 /// Note that PyIntegerSet instances may be uniqued, so the returned object 961 /// may be a pre-existing object. Integer sets are owned by the context. 962 static PyIntegerSet createFromCapsule(pybind11::object capsule); 963 964 private: 965 MlirIntegerSet integerSet; 966 }; 967 968 /// Bindings for MLIR symbol tables. 969 class PySymbolTable { 970 public: 971 /// Constructs a symbol table for the given operation. 972 explicit PySymbolTable(PyOperationBase &operation); 973 974 /// Destroys the symbol table. 975 ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); } 976 977 /// Returns the symbol (opview) with the given name, throws if there is no 978 /// such symbol in the table. 979 pybind11::object dunderGetItem(const std::string &name); 980 981 /// Removes the given operation from the symbol table and erases it. 982 void erase(PyOperationBase &symbol); 983 984 /// Removes the operation with the given name from the symbol table and erases 985 /// it, throws if there is no such symbol in the table. 986 void dunderDel(const std::string &name); 987 988 /// Inserts the given operation into the symbol table. The operation must have 989 /// the symbol trait. 990 PyAttribute insert(PyOperationBase &symbol); 991 992 /// Gets and sets the name of a symbol op. 993 static PyAttribute getSymbolName(PyOperationBase &symbol); 994 static void setSymbolName(PyOperationBase &symbol, const std::string &name); 995 996 /// Gets and sets the visibility of a symbol op. 997 static PyAttribute getVisibility(PyOperationBase &symbol); 998 static void setVisibility(PyOperationBase &symbol, 999 const std::string &visibility); 1000 1001 /// Replaces all symbol uses within an operation. See the API 1002 /// mlirSymbolTableReplaceAllSymbolUses for all caveats. 1003 static void replaceAllSymbolUses(const std::string &oldSymbol, 1004 const std::string &newSymbol, 1005 PyOperationBase &from); 1006 1007 /// Walks all symbol tables under and including 'from'. 1008 static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, 1009 pybind11::object callback); 1010 1011 /// Casts the bindings class into the C API structure. 1012 operator MlirSymbolTable() { return symbolTable; } 1013 1014 private: 1015 PyOperationRef operation; 1016 MlirSymbolTable symbolTable; 1017 }; 1018 1019 void populateIRAffine(pybind11::module &m); 1020 void populateIRAttributes(pybind11::module &m); 1021 void populateIRCore(pybind11::module &m); 1022 void populateIRInterfaces(pybind11::module &m); 1023 void populateIRTypes(pybind11::module &m); 1024 1025 } // namespace python 1026 } // namespace mlir 1027 1028 namespace pybind11 { 1029 namespace detail { 1030 1031 template <> 1032 struct type_caster<mlir::python::DefaultingPyMlirContext> 1033 : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {}; 1034 template <> 1035 struct type_caster<mlir::python::DefaultingPyLocation> 1036 : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {}; 1037 1038 } // namespace detail 1039 } // namespace pybind11 1040 1041 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H 1042