1 //===- IRModules.h - IR Submodules of pybind module -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H 10 #define MLIR_BINDINGS_PYTHON_IRMODULES_H 11 12 #include <utility> 13 #include <vector> 14 15 #include "PybindUtils.h" 16 17 #include "mlir-c/AffineExpr.h" 18 #include "mlir-c/AffineMap.h" 19 #include "mlir-c/Diagnostics.h" 20 #include "mlir-c/IR.h" 21 #include "mlir-c/IntegerSet.h" 22 #include "llvm/ADT/DenseMap.h" 23 #include "llvm/ADT/Optional.h" 24 25 namespace mlir { 26 namespace python { 27 28 class PyBlock; 29 class PyDiagnostic; 30 class PyDiagnosticHandler; 31 class PyInsertionPoint; 32 class PyLocation; 33 class DefaultingPyLocation; 34 class PyMlirContext; 35 class DefaultingPyMlirContext; 36 class PyModule; 37 class PyOperation; 38 class PyType; 39 class PySymbolTable; 40 class PyValue; 41 42 /// Template for a reference to a concrete type which captures a python 43 /// reference to its underlying python object. 44 template <typename T> 45 class PyObjectRef { 46 public: 47 PyObjectRef(T *referrent, pybind11::object object) 48 : referrent(referrent), object(std::move(object)) { 49 assert(this->referrent && 50 "cannot construct PyObjectRef with null referrent"); 51 assert(this->object && "cannot construct PyObjectRef with null object"); 52 } 53 PyObjectRef(PyObjectRef &&other) 54 : referrent(other.referrent), object(std::move(other.object)) { 55 other.referrent = nullptr; 56 assert(!other.object); 57 } 58 PyObjectRef(const PyObjectRef &other) 59 : referrent(other.referrent), object(other.object /* copies */) {} 60 ~PyObjectRef() = default; 61 62 int getRefCount() { 63 if (!object) 64 return 0; 65 return object.ref_count(); 66 } 67 68 /// Releases the object held by this instance, returning it. 69 /// This is the proper thing to return from a function that wants to return 70 /// the reference. Note that this does not work from initializers. 71 pybind11::object releaseObject() { 72 assert(referrent && object); 73 referrent = nullptr; 74 auto stolen = std::move(object); 75 return stolen; 76 } 77 78 T *get() { return referrent; } 79 T *operator->() { 80 assert(referrent && object); 81 return referrent; 82 } 83 pybind11::object getObject() { 84 assert(referrent && object); 85 return object; 86 } 87 operator bool() const { return referrent && object; } 88 89 private: 90 T *referrent; 91 pybind11::object object; 92 }; 93 94 /// Tracks an entry in the thread context stack. New entries are pushed onto 95 /// here for each with block that activates a new InsertionPoint, Context or 96 /// Location. 97 /// 98 /// Pushing either a Location or InsertionPoint also pushes its associated 99 /// Context. Pushing a Context will not modify the Location or InsertionPoint 100 /// unless if they are from a different context, in which case, they are 101 /// cleared. 102 class PyThreadContextEntry { 103 public: 104 enum class FrameKind { 105 Context, 106 InsertionPoint, 107 Location, 108 }; 109 110 PyThreadContextEntry(FrameKind frameKind, pybind11::object context, 111 pybind11::object insertionPoint, 112 pybind11::object location) 113 : context(std::move(context)), insertionPoint(std::move(insertionPoint)), 114 location(std::move(location)), frameKind(frameKind) {} 115 116 /// Gets the top of stack context and return nullptr if not defined. 117 static PyMlirContext *getDefaultContext(); 118 119 /// Gets the top of stack insertion point and return nullptr if not defined. 120 static PyInsertionPoint *getDefaultInsertionPoint(); 121 122 /// Gets the top of stack location and returns nullptr if not defined. 123 static PyLocation *getDefaultLocation(); 124 125 PyMlirContext *getContext(); 126 PyInsertionPoint *getInsertionPoint(); 127 PyLocation *getLocation(); 128 FrameKind getFrameKind() { return frameKind; } 129 130 /// Stack management. 131 static PyThreadContextEntry *getTopOfStack(); 132 static pybind11::object pushContext(PyMlirContext &context); 133 static void popContext(PyMlirContext &context); 134 static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); 135 static void popInsertionPoint(PyInsertionPoint &insertionPoint); 136 static pybind11::object pushLocation(PyLocation &location); 137 static void popLocation(PyLocation &location); 138 139 /// Gets the thread local stack. 140 static std::vector<PyThreadContextEntry> &getStack(); 141 142 private: 143 static void push(FrameKind frameKind, pybind11::object context, 144 pybind11::object insertionPoint, pybind11::object location); 145 146 /// An object reference to the PyContext. 147 pybind11::object context; 148 /// An object reference to the current insertion point. 149 pybind11::object insertionPoint; 150 /// An object reference to the current location. 151 pybind11::object location; 152 // The kind of push that was performed. 153 FrameKind frameKind; 154 }; 155 156 /// Wrapper around MlirContext. 157 using PyMlirContextRef = PyObjectRef<PyMlirContext>; 158 class PyMlirContext { 159 public: 160 PyMlirContext() = delete; 161 PyMlirContext(const PyMlirContext &) = delete; 162 PyMlirContext(PyMlirContext &&) = delete; 163 164 /// For the case of a python __init__ (py::init) method, pybind11 is quite 165 /// strict about needing to return a pointer that is not yet associated to 166 /// an py::object. Since the forContext() method acts like a pool, possibly 167 /// returning a recycled context, it does not satisfy this need. The usual 168 /// way in python to accomplish such a thing is to override __new__, but 169 /// that is also not supported by pybind11. Instead, we use this entry 170 /// point which always constructs a fresh context (which cannot alias an 171 /// existing one because it is fresh). 172 static PyMlirContext *createNewContextForInit(); 173 174 /// Returns a context reference for the singleton PyMlirContext wrapper for 175 /// the given context. 176 static PyMlirContextRef forContext(MlirContext context); 177 ~PyMlirContext(); 178 179 /// Accesses the underlying MlirContext. 180 MlirContext get() { return context; } 181 182 /// Gets a strong reference to this context, which will ensure it is kept 183 /// alive for the life of the reference. 184 PyMlirContextRef getRef() { 185 return PyMlirContextRef(this, pybind11::cast(this)); 186 } 187 188 /// Gets a capsule wrapping the void* within the MlirContext. 189 pybind11::object getCapsule(); 190 191 /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. 192 /// Note that PyMlirContext instances are uniqued, so the returned object 193 /// may be a pre-existing object. Ownership of the underlying MlirContext 194 /// is taken by calling this function. 195 static pybind11::object createFromCapsule(pybind11::object capsule); 196 197 /// Gets the count of live context objects. Used for testing. 198 static size_t getLiveCount(); 199 200 /// Gets the count of live operations associated with this context. 201 /// Used for testing. 202 size_t getLiveOperationCount(); 203 204 /// Gets the count of live modules associated with this context. 205 /// Used for testing. 206 size_t getLiveModuleCount(); 207 208 /// Enter and exit the context manager. 209 pybind11::object contextEnter(); 210 void contextExit(const pybind11::object &excType, 211 const pybind11::object &excVal, 212 const pybind11::object &excTb); 213 214 /// Attaches a Python callback as a diagnostic handler, returning a 215 /// registration object (internally a PyDiagnosticHandler). 216 pybind11::object attachDiagnosticHandler(pybind11::object callback); 217 218 private: 219 PyMlirContext(MlirContext context); 220 // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, 221 // preserving the relationship that an MlirContext maps to a single 222 // PyMlirContext wrapper. This could be replaced in the future with an 223 // extension mechanism on the MlirContext for stashing user pointers. 224 // Note that this holds a handle, which does not imply ownership. 225 // Mappings will be removed when the context is destructed. 226 using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>; 227 static LiveContextMap &getLiveContexts(); 228 229 // Interns all live modules associated with this context. Modules tracked 230 // in this map are valid. When a module is invalidated, it is removed 231 // from this map, and while it still exists as an instance, any 232 // attempt to access it will raise an error. 233 using LiveModuleMap = 234 llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>; 235 LiveModuleMap liveModules; 236 237 // Interns all live operations associated with this context. Operations 238 // tracked in this map are valid. When an operation is invalidated, it is 239 // removed from this map, and while it still exists as an instance, any 240 // attempt to access it will raise an error. 241 using LiveOperationMap = 242 llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>; 243 LiveOperationMap liveOperations; 244 245 MlirContext context; 246 friend class PyModule; 247 friend class PyOperation; 248 }; 249 250 /// Used in function arguments when None should resolve to the current context 251 /// manager set instance. 252 class DefaultingPyMlirContext 253 : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { 254 public: 255 using Defaulting::Defaulting; 256 static constexpr const char kTypeDescription[] = "mlir.ir.Context"; 257 static PyMlirContext &resolve(); 258 }; 259 260 /// Base class for all objects that directly or indirectly depend on an 261 /// MlirContext. The lifetime of the context will extend at least to the 262 /// lifetime of these instances. 263 /// Immutable objects that depend on a context extend this directly. 264 class BaseContextObject { 265 public: 266 BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { 267 assert(this->contextRef && 268 "context object constructed with null context ref"); 269 } 270 271 /// Accesses the context reference. 272 PyMlirContextRef &getContext() { return contextRef; } 273 274 private: 275 PyMlirContextRef contextRef; 276 }; 277 278 /// Python class mirroring the C MlirDiagnostic struct. Note that these structs 279 /// are only valid for the duration of a diagnostic callback and attempting 280 /// to access them outside of that will raise an exception. This applies to 281 /// nested diagnostics (in the notes) as well. 282 class PyDiagnostic { 283 public: 284 PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {} 285 void invalidate(); 286 bool isValid() { return valid; } 287 MlirDiagnosticSeverity getSeverity(); 288 PyLocation getLocation(); 289 pybind11::str getMessage(); 290 pybind11::tuple getNotes(); 291 292 private: 293 MlirDiagnostic diagnostic; 294 295 void checkValid(); 296 /// If notes have been materialized from the diagnostic, then this will 297 /// be populated with the corresponding objects (all castable to 298 /// PyDiagnostic). 299 llvm::Optional<pybind11::tuple> materializedNotes; 300 bool valid = true; 301 }; 302 303 /// Represents a diagnostic handler attached to the context. The handler's 304 /// callback will be invoked with PyDiagnostic instances until the detach() 305 /// method is called or the context is destroyed. A diagnostic handler can be 306 /// the subject of a `with` block, which will detach it when the block exits. 307 /// 308 /// Since diagnostic handlers can call back into Python code which can do 309 /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions, 310 /// etc), this is generally not deemed to be a great user-level API. Users 311 /// should generally use some form of DiagnosticCollector. If the handler raises 312 /// any exceptions, they will just be emitted to stderr and dropped. 313 /// 314 /// The unique usage of this class means that its lifetime management is 315 /// different from most other parts of the API. Instances are always created 316 /// in an attached state and can transition to a detached state by either: 317 /// a) The context being destroyed and unregistering all handlers. 318 /// b) An explicit call to detach(). 319 /// The object may remain live from a Python perspective for an arbitrary time 320 /// after detachment, but there is nothing the user can do with it (since there 321 /// is no way to attach an existing handler object). 322 class PyDiagnosticHandler { 323 public: 324 PyDiagnosticHandler(MlirContext context, pybind11::object callback); 325 ~PyDiagnosticHandler(); 326 327 bool isAttached() { return registeredID.hasValue(); } 328 bool getHadError() { return hadError; } 329 330 /// Detaches the handler. Does nothing if not attached. 331 void detach(); 332 333 pybind11::object contextEnter() { return pybind11::cast(this); } 334 void contextExit(const pybind11::object &excType, 335 const pybind11::object &excVal, 336 const pybind11::object &excTb) { 337 detach(); 338 } 339 340 private: 341 MlirContext context; 342 pybind11::object callback; 343 llvm::Optional<MlirDiagnosticHandlerID> registeredID; 344 bool hadError = false; 345 friend class PyMlirContext; 346 }; 347 348 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in 349 /// order to differentiate it from the `Dialect` base class which is extended by 350 /// plugins which extend dialect functionality through extension python code. 351 /// This should be seen as the "low-level" object and `Dialect` as the 352 /// high-level, user facing object. 353 class PyDialectDescriptor : public BaseContextObject { 354 public: 355 PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) 356 : BaseContextObject(std::move(contextRef)), dialect(dialect) {} 357 358 MlirDialect get() { return dialect; } 359 360 private: 361 MlirDialect dialect; 362 }; 363 364 /// User-level object for accessing dialects with dotted syntax such as: 365 /// ctx.dialect.std 366 class PyDialects : public BaseContextObject { 367 public: 368 PyDialects(PyMlirContextRef contextRef) 369 : BaseContextObject(std::move(contextRef)) {} 370 371 MlirDialect getDialectForKey(const std::string &key, bool attrError); 372 }; 373 374 /// User-level dialect object. For dialects that have a registered extension, 375 /// this will be the base class of the extension dialect type. For un-extended, 376 /// objects of this type will be returned directly. 377 class PyDialect { 378 public: 379 PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} 380 381 pybind11::object getDescriptor() { return descriptor; } 382 383 private: 384 pybind11::object descriptor; 385 }; 386 387 /// Wrapper around an MlirLocation. 388 class PyLocation : public BaseContextObject { 389 public: 390 PyLocation(PyMlirContextRef contextRef, MlirLocation loc) 391 : BaseContextObject(std::move(contextRef)), loc(loc) {} 392 393 operator MlirLocation() const { return loc; } 394 MlirLocation get() const { return loc; } 395 396 /// Enter and exit the context manager. 397 pybind11::object contextEnter(); 398 void contextExit(const pybind11::object &excType, 399 const pybind11::object &excVal, 400 const pybind11::object &excTb); 401 402 /// Gets a capsule wrapping the void* within the MlirLocation. 403 pybind11::object getCapsule(); 404 405 /// Creates a PyLocation from the MlirLocation wrapped by a capsule. 406 /// Note that PyLocation instances are uniqued, so the returned object 407 /// may be a pre-existing object. Ownership of the underlying MlirLocation 408 /// is taken by calling this function. 409 static PyLocation createFromCapsule(pybind11::object capsule); 410 411 private: 412 MlirLocation loc; 413 }; 414 415 /// Used in function arguments when None should resolve to the current context 416 /// manager set instance. 417 class DefaultingPyLocation 418 : public Defaulting<DefaultingPyLocation, PyLocation> { 419 public: 420 using Defaulting::Defaulting; 421 static constexpr const char kTypeDescription[] = "mlir.ir.Location"; 422 static PyLocation &resolve(); 423 424 operator MlirLocation() const { return *get(); } 425 }; 426 427 /// Wrapper around MlirModule. 428 /// This is the top-level, user-owned object that contains regions/ops/blocks. 429 class PyModule; 430 using PyModuleRef = PyObjectRef<PyModule>; 431 class PyModule : public BaseContextObject { 432 public: 433 /// Returns a PyModule reference for the given MlirModule. This may return 434 /// a pre-existing or new object. 435 static PyModuleRef forModule(MlirModule module); 436 PyModule(PyModule &) = delete; 437 PyModule(PyMlirContext &&) = delete; 438 ~PyModule(); 439 440 /// Gets the backing MlirModule. 441 MlirModule get() { return module; } 442 443 /// Gets a strong reference to this module. 444 PyModuleRef getRef() { 445 return PyModuleRef(this, 446 pybind11::reinterpret_borrow<pybind11::object>(handle)); 447 } 448 449 /// Gets a capsule wrapping the void* within the MlirModule. 450 /// Note that the module does not (yet) provide a corresponding factory for 451 /// constructing from a capsule as that would require uniquing PyModule 452 /// instances, which is not currently done. 453 pybind11::object getCapsule(); 454 455 /// Creates a PyModule from the MlirModule wrapped by a capsule. 456 /// Note that PyModule instances are uniqued, so the returned object 457 /// may be a pre-existing object. Ownership of the underlying MlirModule 458 /// is taken by calling this function. 459 static pybind11::object createFromCapsule(pybind11::object capsule); 460 461 private: 462 PyModule(PyMlirContextRef contextRef, MlirModule module); 463 MlirModule module; 464 pybind11::handle handle; 465 }; 466 467 /// Base class for PyOperation and PyOpView which exposes the primary, user 468 /// visible methods for manipulating it. 469 class PyOperationBase { 470 public: 471 virtual ~PyOperationBase() = default; 472 /// Implements the bound 'print' method and helps with others. 473 void print(pybind11::object fileObject, bool binary, 474 llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo, 475 bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, 476 bool assumeVerified); 477 pybind11::object getAsm(bool binary, 478 llvm::Optional<int64_t> largeElementsLimit, 479 bool enableDebugInfo, bool prettyDebugInfo, 480 bool printGenericOpForm, bool useLocalScope, 481 bool assumeVerified); 482 483 /// Moves the operation before or after the other operation. 484 void moveAfter(PyOperationBase &other); 485 void moveBefore(PyOperationBase &other); 486 487 /// Each must provide access to the raw Operation. 488 virtual PyOperation &getOperation() = 0; 489 }; 490 491 /// Wrapper around PyOperation. 492 /// Operations exist in either an attached (dependent) or detached (top-level) 493 /// state. In the detached state (as on creation), an operation is owned by 494 /// the creator and its lifetime extends either until its reference count 495 /// drops to zero or it is attached to a parent, at which point its lifetime 496 /// is bounded by its top-level parent reference. 497 class PyOperation; 498 using PyOperationRef = PyObjectRef<PyOperation>; 499 class PyOperation : public PyOperationBase, public BaseContextObject { 500 public: 501 ~PyOperation() override; 502 PyOperation &getOperation() override { return *this; } 503 504 /// Returns a PyOperation for the given MlirOperation, optionally associating 505 /// it with a parentKeepAlive. 506 static PyOperationRef 507 forOperation(PyMlirContextRef contextRef, MlirOperation operation, 508 pybind11::object parentKeepAlive = pybind11::object()); 509 510 /// Creates a detached operation. The operation must not be associated with 511 /// any existing live operation. 512 static PyOperationRef 513 createDetached(PyMlirContextRef contextRef, MlirOperation operation, 514 pybind11::object parentKeepAlive = pybind11::object()); 515 516 /// Detaches the operation from its parent block and updates its state 517 /// accordingly. 518 void detachFromParent() { 519 mlirOperationRemoveFromParent(getOperation()); 520 setDetached(); 521 parentKeepAlive = pybind11::object(); 522 } 523 524 /// Gets the backing operation. 525 operator MlirOperation() const { return get(); } 526 MlirOperation get() const { 527 checkValid(); 528 return operation; 529 } 530 531 PyOperationRef getRef() { 532 return PyOperationRef( 533 this, pybind11::reinterpret_borrow<pybind11::object>(handle)); 534 } 535 536 bool isAttached() { return attached; } 537 void setAttached(const pybind11::object &parent = pybind11::object()) { 538 assert(!attached && "operation already attached"); 539 attached = true; 540 } 541 void setDetached() { 542 assert(attached && "operation already detached"); 543 attached = false; 544 } 545 void checkValid() const; 546 547 /// Gets the owning block or raises an exception if the operation has no 548 /// owning block. 549 PyBlock getBlock(); 550 551 /// Gets the parent operation or raises an exception if the operation has 552 /// no parent. 553 llvm::Optional<PyOperationRef> getParentOperation(); 554 555 /// Gets a capsule wrapping the void* within the MlirOperation. 556 pybind11::object getCapsule(); 557 558 /// Creates a PyOperation from the MlirOperation wrapped by a capsule. 559 /// Ownership of the underlying MlirOperation is taken by calling this 560 /// function. 561 static pybind11::object createFromCapsule(pybind11::object capsule); 562 563 /// Creates an operation. See corresponding python docstring. 564 static pybind11::object 565 create(const std::string &name, llvm::Optional<std::vector<PyType *>> results, 566 llvm::Optional<std::vector<PyValue *>> operands, 567 llvm::Optional<pybind11::dict> attributes, 568 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 569 DefaultingPyLocation location, const pybind11::object &ip); 570 571 /// Creates an OpView suitable for this operation. 572 pybind11::object createOpView(); 573 574 /// Erases the underlying MlirOperation, removes its pointer from the 575 /// parent context's live operations map, and sets the valid bit false. 576 void erase(); 577 578 private: 579 PyOperation(PyMlirContextRef contextRef, MlirOperation operation); 580 static PyOperationRef createInstance(PyMlirContextRef contextRef, 581 MlirOperation operation, 582 pybind11::object parentKeepAlive); 583 584 MlirOperation operation; 585 pybind11::handle handle; 586 // Keeps the parent alive, regardless of whether it is an Operation or 587 // Module. 588 // TODO: As implemented, this facility is only sufficient for modeling the 589 // trivial module parent back-reference. Generalize this to also account for 590 // transitions from detached to attached and address TODOs in the 591 // ir_operation.py regarding testing corresponding lifetime guarantees. 592 pybind11::object parentKeepAlive; 593 bool attached = true; 594 bool valid = true; 595 596 friend class PyOperationBase; 597 friend class PySymbolTable; 598 }; 599 600 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for 601 /// providing more instance-specific accessors and serve as the base class for 602 /// custom ODS-style operation classes. Since this class is subclass on the 603 /// python side, it must present an __init__ method that operates in pure 604 /// python types. 605 class PyOpView : public PyOperationBase { 606 public: 607 PyOpView(const pybind11::object &operationObject); 608 PyOperation &getOperation() override { return operation; } 609 610 static pybind11::object createRawSubclass(const pybind11::object &userClass); 611 612 pybind11::object getOperationObject() { return operationObject; } 613 614 static pybind11::object 615 buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList, 616 pybind11::list operandList, 617 llvm::Optional<pybind11::dict> attributes, 618 llvm::Optional<std::vector<PyBlock *>> successors, 619 llvm::Optional<int> regions, DefaultingPyLocation location, 620 const pybind11::object &maybeIp); 621 622 private: 623 PyOperation &operation; // For efficient, cast-free access from C++ 624 pybind11::object operationObject; // Holds the reference. 625 }; 626 627 /// Wrapper around an MlirRegion. 628 /// Regions are managed completely by their containing operation. Unlike the 629 /// C++ API, the python API does not support detached regions. 630 class PyRegion { 631 public: 632 PyRegion(PyOperationRef parentOperation, MlirRegion region) 633 : parentOperation(std::move(parentOperation)), region(region) { 634 assert(!mlirRegionIsNull(region) && "python region cannot be null"); 635 } 636 operator MlirRegion() const { return region; } 637 638 MlirRegion get() { return region; } 639 PyOperationRef &getParentOperation() { return parentOperation; } 640 641 void checkValid() { return parentOperation->checkValid(); } 642 643 private: 644 PyOperationRef parentOperation; 645 MlirRegion region; 646 }; 647 648 /// Wrapper around an MlirBlock. 649 /// Blocks are managed completely by their containing operation. Unlike the 650 /// C++ API, the python API does not support detached blocks. 651 class PyBlock { 652 public: 653 PyBlock(PyOperationRef parentOperation, MlirBlock block) 654 : parentOperation(std::move(parentOperation)), block(block) { 655 assert(!mlirBlockIsNull(block) && "python block cannot be null"); 656 } 657 658 MlirBlock get() { return block; } 659 PyOperationRef &getParentOperation() { return parentOperation; } 660 661 void checkValid() { return parentOperation->checkValid(); } 662 663 private: 664 PyOperationRef parentOperation; 665 MlirBlock block; 666 }; 667 668 /// An insertion point maintains a pointer to a Block and a reference operation. 669 /// Calls to insert() will insert a new operation before the 670 /// reference operation. If the reference operation is null, then appends to 671 /// the end of the block. 672 class PyInsertionPoint { 673 public: 674 /// Creates an insertion point positioned after the last operation in the 675 /// block, but still inside the block. 676 PyInsertionPoint(PyBlock &block); 677 /// Creates an insertion point positioned before a reference operation. 678 PyInsertionPoint(PyOperationBase &beforeOperationBase); 679 680 /// Shortcut to create an insertion point at the beginning of the block. 681 static PyInsertionPoint atBlockBegin(PyBlock &block); 682 /// Shortcut to create an insertion point before the block terminator. 683 static PyInsertionPoint atBlockTerminator(PyBlock &block); 684 685 /// Inserts an operation. 686 void insert(PyOperationBase &operationBase); 687 688 /// Enter and exit the context manager. 689 pybind11::object contextEnter(); 690 void contextExit(const pybind11::object &excType, 691 const pybind11::object &excVal, 692 const pybind11::object &excTb); 693 694 PyBlock &getBlock() { return block; } 695 696 private: 697 // Trampoline constructor that avoids null initializing members while 698 // looking up parents. 699 PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation) 700 : refOperation(std::move(refOperation)), block(std::move(block)) {} 701 702 llvm::Optional<PyOperationRef> refOperation; 703 PyBlock block; 704 }; 705 /// Wrapper around the generic MlirType. 706 /// The lifetime of a type is bound by the PyContext that created it. 707 class PyType : public BaseContextObject { 708 public: 709 PyType(PyMlirContextRef contextRef, MlirType type) 710 : BaseContextObject(std::move(contextRef)), type(type) {} 711 bool operator==(const PyType &other); 712 operator MlirType() const { return type; } 713 MlirType get() const { return type; } 714 715 /// Gets a capsule wrapping the void* within the MlirType. 716 pybind11::object getCapsule(); 717 718 /// Creates a PyType from the MlirType wrapped by a capsule. 719 /// Note that PyType instances are uniqued, so the returned object 720 /// may be a pre-existing object. Ownership of the underlying MlirType 721 /// is taken by calling this function. 722 static PyType createFromCapsule(pybind11::object capsule); 723 724 private: 725 MlirType type; 726 }; 727 728 /// CRTP base classes for Python types that subclass Type and should be 729 /// castable from it (i.e. via something like IntegerType(t)). 730 /// By default, type class hierarchies are one level deep (i.e. a 731 /// concrete type class extends PyType); however, intermediate python-visible 732 /// base classes can be modeled by specifying a BaseTy. 733 template <typename DerivedTy, typename BaseTy = PyType> 734 class PyConcreteType : public BaseTy { 735 public: 736 // Derived classes must define statics for: 737 // IsAFunctionTy isaFunction 738 // const char *pyClassName 739 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 740 using IsAFunctionTy = bool (*)(MlirType); 741 742 PyConcreteType() = default; 743 PyConcreteType(PyMlirContextRef contextRef, MlirType t) 744 : BaseTy(std::move(contextRef), t) {} 745 PyConcreteType(PyType &orig) 746 : PyConcreteType(orig.getContext(), castFrom(orig)) {} 747 748 static MlirType castFrom(PyType &orig) { 749 if (!DerivedTy::isaFunction(orig)) { 750 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 751 throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + 752 DerivedTy::pyClassName + 753 " (from " + origRepr + ")"); 754 } 755 return orig; 756 } 757 758 static void bind(pybind11::module &m) { 759 auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); 760 cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(), 761 pybind11::arg("cast_from_type")); 762 cls.def_static( 763 "isinstance", 764 [](PyType &otherType) -> bool { 765 return DerivedTy::isaFunction(otherType); 766 }, 767 pybind11::arg("other")); 768 DerivedTy::bindDerived(cls); 769 } 770 771 /// Implemented by derived classes to add methods to the Python subclass. 772 static void bindDerived(ClassTy &m) {} 773 }; 774 775 /// Wrapper around the generic MlirAttribute. 776 /// The lifetime of a type is bound by the PyContext that created it. 777 class PyAttribute : public BaseContextObject { 778 public: 779 PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 780 : BaseContextObject(std::move(contextRef)), attr(attr) {} 781 bool operator==(const PyAttribute &other); 782 operator MlirAttribute() const { return attr; } 783 MlirAttribute get() const { return attr; } 784 785 /// Gets a capsule wrapping the void* within the MlirAttribute. 786 pybind11::object getCapsule(); 787 788 /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. 789 /// Note that PyAttribute instances are uniqued, so the returned object 790 /// may be a pre-existing object. Ownership of the underlying MlirAttribute 791 /// is taken by calling this function. 792 static PyAttribute createFromCapsule(pybind11::object capsule); 793 794 private: 795 MlirAttribute attr; 796 }; 797 798 /// Represents a Python MlirNamedAttr, carrying an optional owned name. 799 /// TODO: Refactor this and the C-API to be based on an Identifier owned 800 /// by the context so as to avoid ownership issues here. 801 class PyNamedAttribute { 802 public: 803 /// Constructs a PyNamedAttr that retains an owned name. This should be 804 /// used in any code that originates an MlirNamedAttribute from a python 805 /// string. 806 /// The lifetime of the PyNamedAttr must extend to the lifetime of the 807 /// passed attribute. 808 PyNamedAttribute(MlirAttribute attr, std::string ownedName); 809 810 MlirNamedAttribute namedAttr; 811 812 private: 813 // Since the MlirNamedAttr contains an internal pointer to the actual 814 // memory of the owned string, it must be heap allocated to remain valid. 815 // Otherwise, strings that fit within the small object optimization threshold 816 // will have their memory address change as the containing object is moved, 817 // resulting in an invalid aliased pointer. 818 std::unique_ptr<std::string> ownedName; 819 }; 820 821 /// CRTP base classes for Python attributes that subclass Attribute and should 822 /// be castable from it (i.e. via something like StringAttr(attr)). 823 /// By default, attribute class hierarchies are one level deep (i.e. a 824 /// concrete attribute class extends PyAttribute); however, intermediate 825 /// python-visible base classes can be modeled by specifying a BaseTy. 826 template <typename DerivedTy, typename BaseTy = PyAttribute> 827 class PyConcreteAttribute : public BaseTy { 828 public: 829 // Derived classes must define statics for: 830 // IsAFunctionTy isaFunction 831 // const char *pyClassName 832 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 833 using IsAFunctionTy = bool (*)(MlirAttribute); 834 835 PyConcreteAttribute() = default; 836 PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 837 : BaseTy(std::move(contextRef), attr) {} 838 PyConcreteAttribute(PyAttribute &orig) 839 : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} 840 841 static MlirAttribute castFrom(PyAttribute &orig) { 842 if (!DerivedTy::isaFunction(orig)) { 843 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 844 throw SetPyError(PyExc_ValueError, 845 llvm::Twine("Cannot cast attribute to ") + 846 DerivedTy::pyClassName + " (from " + origRepr + ")"); 847 } 848 return orig; 849 } 850 851 static void bind(pybind11::module &m) { 852 auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), 853 pybind11::module_local()); 854 cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(), 855 pybind11::arg("cast_from_attr")); 856 cls.def_static( 857 "isinstance", 858 [](PyAttribute &otherAttr) -> bool { 859 return DerivedTy::isaFunction(otherAttr); 860 }, 861 pybind11::arg("other")); 862 cls.def_property_readonly("type", [](PyAttribute &attr) { 863 return PyType(attr.getContext(), mlirAttributeGetType(attr)); 864 }); 865 DerivedTy::bindDerived(cls); 866 } 867 868 /// Implemented by derived classes to add methods to the Python subclass. 869 static void bindDerived(ClassTy &m) {} 870 }; 871 872 /// Wrapper around the generic MlirValue. 873 /// Values are managed completely by the operation that resulted in their 874 /// definition. For op result value, this is the operation that defines the 875 /// value. For block argument values, this is the operation that contains the 876 /// block to which the value is an argument (blocks cannot be detached in Python 877 /// bindings so such operation always exists). 878 class PyValue { 879 public: 880 PyValue(PyOperationRef parentOperation, MlirValue value) 881 : parentOperation(std::move(parentOperation)), value(value) {} 882 operator MlirValue() const { return value; } 883 884 MlirValue get() { return value; } 885 PyOperationRef &getParentOperation() { return parentOperation; } 886 887 void checkValid() { return parentOperation->checkValid(); } 888 889 /// Gets a capsule wrapping the void* within the MlirValue. 890 pybind11::object getCapsule(); 891 892 /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of 893 /// the underlying MlirValue is still tied to the owning operation. 894 static PyValue createFromCapsule(pybind11::object capsule); 895 896 private: 897 PyOperationRef parentOperation; 898 MlirValue value; 899 }; 900 901 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. 902 class PyAffineExpr : public BaseContextObject { 903 public: 904 PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 905 : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} 906 bool operator==(const PyAffineExpr &other); 907 operator MlirAffineExpr() const { return affineExpr; } 908 MlirAffineExpr get() const { return affineExpr; } 909 910 /// Gets a capsule wrapping the void* within the MlirAffineExpr. 911 pybind11::object getCapsule(); 912 913 /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. 914 /// Note that PyAffineExpr instances are uniqued, so the returned object 915 /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr 916 /// is taken by calling this function. 917 static PyAffineExpr createFromCapsule(pybind11::object capsule); 918 919 PyAffineExpr add(const PyAffineExpr &other) const; 920 PyAffineExpr mul(const PyAffineExpr &other) const; 921 PyAffineExpr floorDiv(const PyAffineExpr &other) const; 922 PyAffineExpr ceilDiv(const PyAffineExpr &other) const; 923 PyAffineExpr mod(const PyAffineExpr &other) const; 924 925 private: 926 MlirAffineExpr affineExpr; 927 }; 928 929 class PyAffineMap : public BaseContextObject { 930 public: 931 PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) 932 : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} 933 bool operator==(const PyAffineMap &other); 934 operator MlirAffineMap() const { return affineMap; } 935 MlirAffineMap get() const { return affineMap; } 936 937 /// Gets a capsule wrapping the void* within the MlirAffineMap. 938 pybind11::object getCapsule(); 939 940 /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. 941 /// Note that PyAffineMap instances are uniqued, so the returned object 942 /// may be a pre-existing object. Ownership of the underlying MlirAffineMap 943 /// is taken by calling this function. 944 static PyAffineMap createFromCapsule(pybind11::object capsule); 945 946 private: 947 MlirAffineMap affineMap; 948 }; 949 950 class PyIntegerSet : public BaseContextObject { 951 public: 952 PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) 953 : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} 954 bool operator==(const PyIntegerSet &other); 955 operator MlirIntegerSet() const { return integerSet; } 956 MlirIntegerSet get() const { return integerSet; } 957 958 /// Gets a capsule wrapping the void* within the MlirIntegerSet. 959 pybind11::object getCapsule(); 960 961 /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. 962 /// Note that PyIntegerSet instances may be uniqued, so the returned object 963 /// may be a pre-existing object. Integer sets are owned by the context. 964 static PyIntegerSet createFromCapsule(pybind11::object capsule); 965 966 private: 967 MlirIntegerSet integerSet; 968 }; 969 970 /// Bindings for MLIR symbol tables. 971 class PySymbolTable { 972 public: 973 /// Constructs a symbol table for the given operation. 974 explicit PySymbolTable(PyOperationBase &operation); 975 976 /// Destroys the symbol table. 977 ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); } 978 979 /// Returns the symbol (opview) with the given name, throws if there is no 980 /// such symbol in the table. 981 pybind11::object dunderGetItem(const std::string &name); 982 983 /// Removes the given operation from the symbol table and erases it. 984 void erase(PyOperationBase &symbol); 985 986 /// Removes the operation with the given name from the symbol table and erases 987 /// it, throws if there is no such symbol in the table. 988 void dunderDel(const std::string &name); 989 990 /// Inserts the given operation into the symbol table. The operation must have 991 /// the symbol trait. 992 PyAttribute insert(PyOperationBase &symbol); 993 994 /// Gets and sets the name of a symbol op. 995 static PyAttribute getSymbolName(PyOperationBase &symbol); 996 static void setSymbolName(PyOperationBase &symbol, const std::string &name); 997 998 /// Gets and sets the visibility of a symbol op. 999 static PyAttribute getVisibility(PyOperationBase &symbol); 1000 static void setVisibility(PyOperationBase &symbol, 1001 const std::string &visibility); 1002 1003 /// Replaces all symbol uses within an operation. See the API 1004 /// mlirSymbolTableReplaceAllSymbolUses for all caveats. 1005 static void replaceAllSymbolUses(const std::string &oldSymbol, 1006 const std::string &newSymbol, 1007 PyOperationBase &from); 1008 1009 /// Walks all symbol tables under and including 'from'. 1010 static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, 1011 pybind11::object callback); 1012 1013 /// Casts the bindings class into the C API structure. 1014 operator MlirSymbolTable() { return symbolTable; } 1015 1016 private: 1017 PyOperationRef operation; 1018 MlirSymbolTable symbolTable; 1019 }; 1020 1021 void populateIRAffine(pybind11::module &m); 1022 void populateIRAttributes(pybind11::module &m); 1023 void populateIRCore(pybind11::module &m); 1024 void populateIRInterfaces(pybind11::module &m); 1025 void populateIRTypes(pybind11::module &m); 1026 1027 } // namespace python 1028 } // namespace mlir 1029 1030 namespace pybind11 { 1031 namespace detail { 1032 1033 template <> 1034 struct type_caster<mlir::python::DefaultingPyMlirContext> 1035 : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {}; 1036 template <> 1037 struct type_caster<mlir::python::DefaultingPyLocation> 1038 : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {}; 1039 1040 } // namespace detail 1041 } // namespace pybind11 1042 1043 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H 1044