1 //===- IRModules.h - IR Submodules of pybind module -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H 10 #define MLIR_BINDINGS_PYTHON_IRMODULES_H 11 12 #include <utility> 13 #include <vector> 14 15 #include "PybindUtils.h" 16 17 #include "mlir-c/AffineExpr.h" 18 #include "mlir-c/AffineMap.h" 19 #include "mlir-c/Diagnostics.h" 20 #include "mlir-c/IR.h" 21 #include "mlir-c/IntegerSet.h" 22 #include "llvm/ADT/DenseMap.h" 23 #include "llvm/ADT/Optional.h" 24 25 namespace mlir { 26 namespace python { 27 28 class PyBlock; 29 class PyDiagnostic; 30 class PyDiagnosticHandler; 31 class PyInsertionPoint; 32 class PyLocation; 33 class DefaultingPyLocation; 34 class PyMlirContext; 35 class DefaultingPyMlirContext; 36 class PyModule; 37 class PyOperation; 38 class PyType; 39 class PySymbolTable; 40 class PyValue; 41 42 /// Template for a reference to a concrete type which captures a python 43 /// reference to its underlying python object. 
44 template <typename T> 45 class PyObjectRef { 46 public: 47 PyObjectRef(T *referrent, pybind11::object object) 48 : referrent(referrent), object(std::move(object)) { 49 assert(this->referrent && 50 "cannot construct PyObjectRef with null referrent"); 51 assert(this->object && "cannot construct PyObjectRef with null object"); 52 } 53 PyObjectRef(PyObjectRef &&other) 54 : referrent(other.referrent), object(std::move(other.object)) { 55 other.referrent = nullptr; 56 assert(!other.object); 57 } 58 PyObjectRef(const PyObjectRef &other) 59 : referrent(other.referrent), object(other.object /* copies */) {} 60 ~PyObjectRef() = default; 61 62 int getRefCount() { 63 if (!object) 64 return 0; 65 return object.ref_count(); 66 } 67 68 /// Releases the object held by this instance, returning it. 69 /// This is the proper thing to return from a function that wants to return 70 /// the reference. Note that this does not work from initializers. 71 pybind11::object releaseObject() { 72 assert(referrent && object); 73 referrent = nullptr; 74 auto stolen = std::move(object); 75 return stolen; 76 } 77 78 T *get() { return referrent; } 79 T *operator->() { 80 assert(referrent && object); 81 return referrent; 82 } 83 pybind11::object getObject() { 84 assert(referrent && object); 85 return object; 86 } 87 operator bool() const { return referrent && object; } 88 89 private: 90 T *referrent; 91 pybind11::object object; 92 }; 93 94 /// Tracks an entry in the thread context stack. New entries are pushed onto 95 /// here for each with block that activates a new InsertionPoint, Context or 96 /// Location. 97 /// 98 /// Pushing either a Location or InsertionPoint also pushes its associated 99 /// Context. Pushing a Context will not modify the Location or InsertionPoint 100 /// unless if they are from a different context, in which case, they are 101 /// cleared. 
class PyThreadContextEntry {
public:
  /// The kind of object whose `with` block pushed this frame.
  enum class FrameKind {
    Context,
    InsertionPoint,
    Location,
  };

  PyThreadContextEntry(FrameKind frameKind, pybind11::object context,
                       pybind11::object insertionPoint,
                       pybind11::object location)
      : context(std::move(context)), insertionPoint(std::move(insertionPoint)),
        location(std::move(location)), frameKind(frameKind) {}

  /// Gets the top of stack context and returns nullptr if not defined.
  static PyMlirContext *getDefaultContext();

  /// Gets the top of stack insertion point and returns nullptr if not defined.
  static PyInsertionPoint *getDefaultInsertionPoint();

  /// Gets the top of stack location and returns nullptr if not defined.
  static PyLocation *getDefaultLocation();

  /// Accessors for the objects recorded in this frame.
  /// NOTE(review): presumably return nullptr when the corresponding slot
  /// holds None; confirm against the implementation.
  PyMlirContext *getContext();
  PyInsertionPoint *getInsertionPoint();
  PyLocation *getLocation();
  FrameKind getFrameKind() { return frameKind; }

  /// Stack management. Each push* has a matching pop* taking the same kind
  /// of object.
  static PyThreadContextEntry *getTopOfStack();
  static pybind11::object pushContext(PyMlirContext &context);
  static void popContext(PyMlirContext &context);
  static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint);
  static void popInsertionPoint(PyInsertionPoint &insertionPoint);
  static pybind11::object pushLocation(PyLocation &location);
  static void popLocation(PyLocation &location);

  /// Gets the thread local stack.
  static std::vector<PyThreadContextEntry> &getStack();

private:
  static void push(FrameKind frameKind, pybind11::object context,
                   pybind11::object insertionPoint, pybind11::object location);

  // The constructor's initializer list is written in this declaration order
  // (context, insertionPoint, location, frameKind); keep them in sync.
  /// An object reference to the PyContext.
  pybind11::object context;
  /// An object reference to the current insertion point.
  pybind11::object insertionPoint;
  /// An object reference to the current location.
  pybind11::object location;
  // The kind of push that was performed.
  FrameKind frameKind;
};

/// Wrapper around MlirContext.
using PyMlirContextRef = PyObjectRef<PyMlirContext>;
class PyMlirContext {
public:
  PyMlirContext() = delete;
  PyMlirContext(const PyMlirContext &) = delete;
  PyMlirContext(PyMlirContext &&) = delete;

  /// For the case of a python __init__ (py::init) method, pybind11 is quite
  /// strict about needing to return a pointer that is not yet associated to
  /// a py::object. Since the forContext() method acts like a pool, possibly
  /// returning a recycled context, it does not satisfy this need. The usual
  /// way in python to accomplish such a thing is to override __new__, but
  /// that is also not supported by pybind11. Instead, we use this entry
  /// point which always constructs a fresh context (which cannot alias an
  /// existing one because it is fresh).
  static PyMlirContext *createNewContextForInit();

  /// Returns a context reference for the singleton PyMlirContext wrapper for
  /// the given context.
  static PyMlirContextRef forContext(MlirContext context);
  ~PyMlirContext();

  /// Accesses the underlying MlirContext.
  MlirContext get() { return context; }

  /// Gets a strong reference to this context, which will ensure it is kept
  /// alive for the life of the reference.
  PyMlirContextRef getRef() {
    return PyMlirContextRef(this, pybind11::cast(this));
  }

  /// Gets a capsule wrapping the void* within the MlirContext.
  pybind11::object getCapsule();

  /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
  /// Note that PyMlirContext instances are uniqued, so the returned object
  /// may be a pre-existing object. Ownership of the underlying MlirContext
  /// is taken by calling this function.
  static pybind11::object createFromCapsule(pybind11::object capsule);

  /// Gets the count of live context objects. Used for testing.
  static size_t getLiveCount();

  /// Gets the count of live operations associated with this context.
  /// Used for testing.
  size_t getLiveOperationCount();

  /// Gets the count of live modules associated with this context.
  /// Used for testing.
  size_t getLiveModuleCount();

  /// Enter and exit the context manager.
  pybind11::object contextEnter();
  void contextExit(const pybind11::object &excType,
                   const pybind11::object &excVal,
                   const pybind11::object &excTb);

  /// Attaches a Python callback as a diagnostic handler, returning a
  /// registration object (internally a PyDiagnosticHandler).
  pybind11::object attachDiagnosticHandler(pybind11::object callback);

private:
  PyMlirContext(MlirContext context);
  // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
  // preserving the relationship that an MlirContext maps to a single
  // PyMlirContext wrapper. This could be replaced in the future with an
  // extension mechanism on the MlirContext for stashing user pointers.
  // Note that this holds a handle, which does not imply ownership.
  // Mappings will be removed when the context is destructed.
  using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
  static LiveContextMap &getLiveContexts();

  // Interns all live modules associated with this context. Modules tracked
  // in this map are valid. When a module is invalidated, it is removed
  // from this map, and while it still exists as an instance, any
  // attempt to access it will raise an error.
  using LiveModuleMap =
      llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
  LiveModuleMap liveModules;

  // Interns all live operations associated with this context. Operations
  // tracked in this map are valid. When an operation is invalidated, it is
  // removed from this map, and while it still exists as an instance, any
  // attempt to access it will raise an error.
  using LiveOperationMap =
      llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
  LiveOperationMap liveOperations;

  MlirContext context;
  // PyModule and PyOperation maintain the live maps above.
  friend class PyModule;
  friend class PyOperation;
};

/// Used in function arguments when None should resolve to the current context
/// manager set instance.
class DefaultingPyMlirContext
    : public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
public:
  using Defaulting::Defaulting;
  static constexpr const char kTypeDescription[] = "mlir.ir.Context";
  static PyMlirContext &resolve();
};

/// Base class for all objects that directly or indirectly depend on an
/// MlirContext. The lifetime of the context will extend at least to the
/// lifetime of these instances.
/// Immutable objects that depend on a context extend this directly.
class BaseContextObject {
public:
  BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
    assert(this->contextRef &&
           "context object constructed with null context ref");
  }

  /// Accesses the context reference.
  PyMlirContextRef &getContext() { return contextRef; }

private:
  // Strong reference that keeps the context alive while this object lives.
  PyMlirContextRef contextRef;
};

/// Python class mirroring the C MlirDiagnostic struct. Note that these structs
/// are only valid for the duration of a diagnostic callback and attempting
/// to access them outside of that will raise an exception. This applies to
/// nested diagnostics (in the notes) as well.
282 class PyDiagnostic { 283 public: 284 PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {} 285 void invalidate(); 286 bool isValid() { return valid; } 287 MlirDiagnosticSeverity getSeverity(); 288 PyLocation getLocation(); 289 pybind11::str getMessage(); 290 pybind11::tuple getNotes(); 291 292 private: 293 MlirDiagnostic diagnostic; 294 295 void checkValid(); 296 /// If notes have been materialized from the diagnostic, then this will 297 /// be populated with the corresponding objects (all castable to 298 /// PyDiagnostic). 299 llvm::Optional<pybind11::tuple> materializedNotes; 300 bool valid = true; 301 }; 302 303 /// Represents a diagnostic handler attached to the context. The handler's 304 /// callback will be invoked with PyDiagnostic instances until the detach() 305 /// method is called or the context is destroyed. A diagnostic handler can be 306 /// the subject of a `with` block, which will detach it when the block exits. 307 /// 308 /// Since diagnostic handlers can call back into Python code which can do 309 /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions, 310 /// etc), this is generally not deemed to be a great user-level API. Users 311 /// should generally use some form of DiagnosticCollector. If the handler raises 312 /// any exceptions, they will just be emitted to stderr and dropped. 313 /// 314 /// The unique usage of this class means that its lifetime management is 315 /// different from most other parts of the API. Instances are always created 316 /// in an attached state and can transition to a detached state by either: 317 /// a) The context being destroyed and unregistering all handlers. 318 /// b) An explicit call to detach(). 319 /// The object may remain live from a Python perspective for an arbitrary time 320 /// after detachment, but there is nothing the user can do with it (since there 321 /// is no way to attach an existing handler object). 
322 class PyDiagnosticHandler { 323 public: 324 PyDiagnosticHandler(MlirContext context, pybind11::object callback); 325 ~PyDiagnosticHandler(); 326 327 bool isAttached() { return registeredID.hasValue(); } 328 bool getHadError() { return hadError; } 329 330 /// Detaches the handler. Does nothing if not attached. 331 void detach(); 332 333 pybind11::object contextEnter() { return pybind11::cast(this); } 334 void contextExit(const pybind11::object &excType, 335 const pybind11::object &excVal, 336 const pybind11::object &excTb) { 337 detach(); 338 } 339 340 private: 341 MlirContext context; 342 pybind11::object callback; 343 llvm::Optional<MlirDiagnosticHandlerID> registeredID; 344 bool hadError = false; 345 friend class PyMlirContext; 346 }; 347 348 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in 349 /// order to differentiate it from the `Dialect` base class which is extended by 350 /// plugins which extend dialect functionality through extension python code. 351 /// This should be seen as the "low-level" object and `Dialect` as the 352 /// high-level, user facing object. 353 class PyDialectDescriptor : public BaseContextObject { 354 public: 355 PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) 356 : BaseContextObject(std::move(contextRef)), dialect(dialect) {} 357 358 MlirDialect get() { return dialect; } 359 360 private: 361 MlirDialect dialect; 362 }; 363 364 /// User-level object for accessing dialects with dotted syntax such as: 365 /// ctx.dialect.std 366 class PyDialects : public BaseContextObject { 367 public: 368 PyDialects(PyMlirContextRef contextRef) 369 : BaseContextObject(std::move(contextRef)) {} 370 371 MlirDialect getDialectForKey(const std::string &key, bool attrError); 372 }; 373 374 /// User-level dialect object. For dialects that have a registered extension, 375 /// this will be the base class of the extension dialect type. For un-extended, 376 /// objects of this type will be returned directly. 
377 class PyDialect { 378 public: 379 PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} 380 381 pybind11::object getDescriptor() { return descriptor; } 382 383 private: 384 pybind11::object descriptor; 385 }; 386 387 /// Wrapper around an MlirLocation. 388 class PyLocation : public BaseContextObject { 389 public: 390 PyLocation(PyMlirContextRef contextRef, MlirLocation loc) 391 : BaseContextObject(std::move(contextRef)), loc(loc) {} 392 393 operator MlirLocation() const { return loc; } 394 MlirLocation get() const { return loc; } 395 396 /// Enter and exit the context manager. 397 pybind11::object contextEnter(); 398 void contextExit(const pybind11::object &excType, 399 const pybind11::object &excVal, 400 const pybind11::object &excTb); 401 402 /// Gets a capsule wrapping the void* within the MlirLocation. 403 pybind11::object getCapsule(); 404 405 /// Creates a PyLocation from the MlirLocation wrapped by a capsule. 406 /// Note that PyLocation instances are uniqued, so the returned object 407 /// may be a pre-existing object. Ownership of the underlying MlirLocation 408 /// is taken by calling this function. 409 static PyLocation createFromCapsule(pybind11::object capsule); 410 411 private: 412 MlirLocation loc; 413 }; 414 415 /// Used in function arguments when None should resolve to the current context 416 /// manager set instance. 417 class DefaultingPyLocation 418 : public Defaulting<DefaultingPyLocation, PyLocation> { 419 public: 420 using Defaulting::Defaulting; 421 static constexpr const char kTypeDescription[] = "mlir.ir.Location"; 422 static PyLocation &resolve(); 423 424 operator MlirLocation() const { return *get(); } 425 }; 426 427 /// Wrapper around MlirModule. 428 /// This is the top-level, user-owned object that contains regions/ops/blocks. 
429 class PyModule; 430 using PyModuleRef = PyObjectRef<PyModule>; 431 class PyModule : public BaseContextObject { 432 public: 433 /// Returns a PyModule reference for the given MlirModule. This may return 434 /// a pre-existing or new object. 435 static PyModuleRef forModule(MlirModule module); 436 PyModule(PyModule &) = delete; 437 PyModule(PyMlirContext &&) = delete; 438 ~PyModule(); 439 440 /// Gets the backing MlirModule. 441 MlirModule get() { return module; } 442 443 /// Gets a strong reference to this module. 444 PyModuleRef getRef() { 445 return PyModuleRef(this, 446 pybind11::reinterpret_borrow<pybind11::object>(handle)); 447 } 448 449 /// Gets a capsule wrapping the void* within the MlirModule. 450 /// Note that the module does not (yet) provide a corresponding factory for 451 /// constructing from a capsule as that would require uniquing PyModule 452 /// instances, which is not currently done. 453 pybind11::object getCapsule(); 454 455 /// Creates a PyModule from the MlirModule wrapped by a capsule. 456 /// Note that PyModule instances are uniqued, so the returned object 457 /// may be a pre-existing object. Ownership of the underlying MlirModule 458 /// is taken by calling this function. 459 static pybind11::object createFromCapsule(pybind11::object capsule); 460 461 private: 462 PyModule(PyMlirContextRef contextRef, MlirModule module); 463 MlirModule module; 464 pybind11::handle handle; 465 }; 466 467 /// Base class for PyOperation and PyOpView which exposes the primary, user 468 /// visible methods for manipulating it. 469 class PyOperationBase { 470 public: 471 virtual ~PyOperationBase() = default; 472 /// Implements the bound 'print' method and helps with others. 
473 void print(pybind11::object fileObject, bool binary, 474 llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo, 475 bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, 476 bool assumeVerified); 477 pybind11::object getAsm(bool binary, 478 llvm::Optional<int64_t> largeElementsLimit, 479 bool enableDebugInfo, bool prettyDebugInfo, 480 bool printGenericOpForm, bool useLocalScope, 481 bool assumeVerified); 482 483 /// Moves the operation before or after the other operation. 484 void moveAfter(PyOperationBase &other); 485 void moveBefore(PyOperationBase &other); 486 487 /// Each must provide access to the raw Operation. 488 virtual PyOperation &getOperation() = 0; 489 }; 490 491 /// Wrapper around PyOperation. 492 /// Operations exist in either an attached (dependent) or detached (top-level) 493 /// state. In the detached state (as on creation), an operation is owned by 494 /// the creator and its lifetime extends either until its reference count 495 /// drops to zero or it is attached to a parent, at which point its lifetime 496 /// is bounded by its top-level parent reference. 497 class PyOperation; 498 using PyOperationRef = PyObjectRef<PyOperation>; 499 class PyOperation : public PyOperationBase, public BaseContextObject { 500 public: 501 ~PyOperation() override; 502 PyOperation &getOperation() override { return *this; } 503 504 /// Returns a PyOperation for the given MlirOperation, optionally associating 505 /// it with a parentKeepAlive. 506 static PyOperationRef 507 forOperation(PyMlirContextRef contextRef, MlirOperation operation, 508 pybind11::object parentKeepAlive = pybind11::object()); 509 510 /// Creates a detached operation. The operation must not be associated with 511 /// any existing live operation. 
512 static PyOperationRef 513 createDetached(PyMlirContextRef contextRef, MlirOperation operation, 514 pybind11::object parentKeepAlive = pybind11::object()); 515 516 /// Detaches the operation from its parent block and updates its state 517 /// accordingly. 518 void detachFromParent() { 519 mlirOperationRemoveFromParent(getOperation()); 520 setDetached(); 521 parentKeepAlive = pybind11::object(); 522 } 523 524 /// Gets the backing operation. 525 operator MlirOperation() const { return get(); } 526 MlirOperation get() const { 527 checkValid(); 528 return operation; 529 } 530 531 PyOperationRef getRef() { 532 return PyOperationRef( 533 this, pybind11::reinterpret_borrow<pybind11::object>(handle)); 534 } 535 536 bool isAttached() { return attached; } 537 void setAttached(const pybind11::object &parent = pybind11::object()) { 538 assert(!attached && "operation already attached"); 539 attached = true; 540 } 541 void setDetached() { 542 assert(attached && "operation already detached"); 543 attached = false; 544 } 545 void checkValid() const; 546 547 /// Gets the owning block or raises an exception if the operation has no 548 /// owning block. 549 PyBlock getBlock(); 550 551 /// Gets the parent operation or raises an exception if the operation has 552 /// no parent. 553 llvm::Optional<PyOperationRef> getParentOperation(); 554 555 /// Gets a capsule wrapping the void* within the MlirOperation. 556 pybind11::object getCapsule(); 557 558 /// Creates a PyOperation from the MlirOperation wrapped by a capsule. 559 /// Ownership of the underlying MlirOperation is taken by calling this 560 /// function. 561 static pybind11::object createFromCapsule(pybind11::object capsule); 562 563 /// Creates an operation. See corresponding python docstring. 
564 static pybind11::object 565 create(const std::string &name, llvm::Optional<std::vector<PyType *>> results, 566 llvm::Optional<std::vector<PyValue *>> operands, 567 llvm::Optional<pybind11::dict> attributes, 568 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 569 DefaultingPyLocation location, const pybind11::object &ip); 570 571 /// Creates an OpView suitable for this operation. 572 pybind11::object createOpView(); 573 574 /// Erases the underlying MlirOperation, removes its pointer from the 575 /// parent context's live operations map, and sets the valid bit false. 576 void erase(); 577 578 /// Clones this operation. 579 pybind11::object clone(const pybind11::object &ip); 580 581 private: 582 PyOperation(PyMlirContextRef contextRef, MlirOperation operation); 583 static PyOperationRef createInstance(PyMlirContextRef contextRef, 584 MlirOperation operation, 585 pybind11::object parentKeepAlive); 586 587 MlirOperation operation; 588 pybind11::handle handle; 589 // Keeps the parent alive, regardless of whether it is an Operation or 590 // Module. 591 // TODO: As implemented, this facility is only sufficient for modeling the 592 // trivial module parent back-reference. Generalize this to also account for 593 // transitions from detached to attached and address TODOs in the 594 // ir_operation.py regarding testing corresponding lifetime guarantees. 595 pybind11::object parentKeepAlive; 596 bool attached = true; 597 bool valid = true; 598 599 friend class PyOperationBase; 600 friend class PySymbolTable; 601 }; 602 603 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for 604 /// providing more instance-specific accessors and serve as the base class for 605 /// custom ODS-style operation classes. Since this class is subclass on the 606 /// python side, it must present an __init__ method that operates in pure 607 /// python types. 
608 class PyOpView : public PyOperationBase { 609 public: 610 PyOpView(const pybind11::object &operationObject); 611 PyOperation &getOperation() override { return operation; } 612 613 static pybind11::object createRawSubclass(const pybind11::object &userClass); 614 615 pybind11::object getOperationObject() { return operationObject; } 616 617 static pybind11::object 618 buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList, 619 pybind11::list operandList, 620 llvm::Optional<pybind11::dict> attributes, 621 llvm::Optional<std::vector<PyBlock *>> successors, 622 llvm::Optional<int> regions, DefaultingPyLocation location, 623 const pybind11::object &maybeIp); 624 625 private: 626 PyOperation &operation; // For efficient, cast-free access from C++ 627 pybind11::object operationObject; // Holds the reference. 628 }; 629 630 /// Wrapper around an MlirRegion. 631 /// Regions are managed completely by their containing operation. Unlike the 632 /// C++ API, the python API does not support detached regions. 633 class PyRegion { 634 public: 635 PyRegion(PyOperationRef parentOperation, MlirRegion region) 636 : parentOperation(std::move(parentOperation)), region(region) { 637 assert(!mlirRegionIsNull(region) && "python region cannot be null"); 638 } 639 operator MlirRegion() const { return region; } 640 641 MlirRegion get() { return region; } 642 PyOperationRef &getParentOperation() { return parentOperation; } 643 644 void checkValid() { return parentOperation->checkValid(); } 645 646 private: 647 PyOperationRef parentOperation; 648 MlirRegion region; 649 }; 650 651 /// Wrapper around an MlirBlock. 652 /// Blocks are managed completely by their containing operation. Unlike the 653 /// C++ API, the python API does not support detached blocks. 
654 class PyBlock { 655 public: 656 PyBlock(PyOperationRef parentOperation, MlirBlock block) 657 : parentOperation(std::move(parentOperation)), block(block) { 658 assert(!mlirBlockIsNull(block) && "python block cannot be null"); 659 } 660 661 MlirBlock get() { return block; } 662 PyOperationRef &getParentOperation() { return parentOperation; } 663 664 void checkValid() { return parentOperation->checkValid(); } 665 666 private: 667 PyOperationRef parentOperation; 668 MlirBlock block; 669 }; 670 671 /// An insertion point maintains a pointer to a Block and a reference operation. 672 /// Calls to insert() will insert a new operation before the 673 /// reference operation. If the reference operation is null, then appends to 674 /// the end of the block. 675 class PyInsertionPoint { 676 public: 677 /// Creates an insertion point positioned after the last operation in the 678 /// block, but still inside the block. 679 PyInsertionPoint(PyBlock &block); 680 /// Creates an insertion point positioned before a reference operation. 681 PyInsertionPoint(PyOperationBase &beforeOperationBase); 682 683 /// Shortcut to create an insertion point at the beginning of the block. 684 static PyInsertionPoint atBlockBegin(PyBlock &block); 685 /// Shortcut to create an insertion point before the block terminator. 686 static PyInsertionPoint atBlockTerminator(PyBlock &block); 687 688 /// Inserts an operation. 689 void insert(PyOperationBase &operationBase); 690 691 /// Enter and exit the context manager. 692 pybind11::object contextEnter(); 693 void contextExit(const pybind11::object &excType, 694 const pybind11::object &excVal, 695 const pybind11::object &excTb); 696 697 PyBlock &getBlock() { return block; } 698 699 private: 700 // Trampoline constructor that avoids null initializing members while 701 // looking up parents. 
702 PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation) 703 : refOperation(std::move(refOperation)), block(std::move(block)) {} 704 705 llvm::Optional<PyOperationRef> refOperation; 706 PyBlock block; 707 }; 708 /// Wrapper around the generic MlirType. 709 /// The lifetime of a type is bound by the PyContext that created it. 710 class PyType : public BaseContextObject { 711 public: 712 PyType(PyMlirContextRef contextRef, MlirType type) 713 : BaseContextObject(std::move(contextRef)), type(type) {} 714 bool operator==(const PyType &other); 715 operator MlirType() const { return type; } 716 MlirType get() const { return type; } 717 718 /// Gets a capsule wrapping the void* within the MlirType. 719 pybind11::object getCapsule(); 720 721 /// Creates a PyType from the MlirType wrapped by a capsule. 722 /// Note that PyType instances are uniqued, so the returned object 723 /// may be a pre-existing object. Ownership of the underlying MlirType 724 /// is taken by calling this function. 725 static PyType createFromCapsule(pybind11::object capsule); 726 727 private: 728 MlirType type; 729 }; 730 731 /// CRTP base classes for Python types that subclass Type and should be 732 /// castable from it (i.e. via something like IntegerType(t)). 733 /// By default, type class hierarchies are one level deep (i.e. a 734 /// concrete type class extends PyType); however, intermediate python-visible 735 /// base classes can be modeled by specifying a BaseTy. 
736 template <typename DerivedTy, typename BaseTy = PyType> 737 class PyConcreteType : public BaseTy { 738 public: 739 // Derived classes must define statics for: 740 // IsAFunctionTy isaFunction 741 // const char *pyClassName 742 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 743 using IsAFunctionTy = bool (*)(MlirType); 744 745 PyConcreteType() = default; 746 PyConcreteType(PyMlirContextRef contextRef, MlirType t) 747 : BaseTy(std::move(contextRef), t) {} 748 PyConcreteType(PyType &orig) 749 : PyConcreteType(orig.getContext(), castFrom(orig)) {} 750 751 static MlirType castFrom(PyType &orig) { 752 if (!DerivedTy::isaFunction(orig)) { 753 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 754 throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + 755 DerivedTy::pyClassName + 756 " (from " + origRepr + ")"); 757 } 758 return orig; 759 } 760 761 static void bind(pybind11::module &m) { 762 auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); 763 cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(), 764 pybind11::arg("cast_from_type")); 765 cls.def_static( 766 "isinstance", 767 [](PyType &otherType) -> bool { 768 return DerivedTy::isaFunction(otherType); 769 }, 770 pybind11::arg("other")); 771 DerivedTy::bindDerived(cls); 772 } 773 774 /// Implemented by derived classes to add methods to the Python subclass. 775 static void bindDerived(ClassTy &m) {} 776 }; 777 778 /// Wrapper around the generic MlirAttribute. 779 /// The lifetime of a type is bound by the PyContext that created it. 780 class PyAttribute : public BaseContextObject { 781 public: 782 PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 783 : BaseContextObject(std::move(contextRef)), attr(attr) {} 784 bool operator==(const PyAttribute &other); 785 operator MlirAttribute() const { return attr; } 786 MlirAttribute get() const { return attr; } 787 788 /// Gets a capsule wrapping the void* within the MlirAttribute. 
789 pybind11::object getCapsule(); 790 791 /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. 792 /// Note that PyAttribute instances are uniqued, so the returned object 793 /// may be a pre-existing object. Ownership of the underlying MlirAttribute 794 /// is taken by calling this function. 795 static PyAttribute createFromCapsule(pybind11::object capsule); 796 797 private: 798 MlirAttribute attr; 799 }; 800 801 /// Represents a Python MlirNamedAttr, carrying an optional owned name. 802 /// TODO: Refactor this and the C-API to be based on an Identifier owned 803 /// by the context so as to avoid ownership issues here. 804 class PyNamedAttribute { 805 public: 806 /// Constructs a PyNamedAttr that retains an owned name. This should be 807 /// used in any code that originates an MlirNamedAttribute from a python 808 /// string. 809 /// The lifetime of the PyNamedAttr must extend to the lifetime of the 810 /// passed attribute. 811 PyNamedAttribute(MlirAttribute attr, std::string ownedName); 812 813 MlirNamedAttribute namedAttr; 814 815 private: 816 // Since the MlirNamedAttr contains an internal pointer to the actual 817 // memory of the owned string, it must be heap allocated to remain valid. 818 // Otherwise, strings that fit within the small object optimization threshold 819 // will have their memory address change as the containing object is moved, 820 // resulting in an invalid aliased pointer. 821 std::unique_ptr<std::string> ownedName; 822 }; 823 824 /// CRTP base classes for Python attributes that subclass Attribute and should 825 /// be castable from it (i.e. via something like StringAttr(attr)). 826 /// By default, attribute class hierarchies are one level deep (i.e. a 827 /// concrete attribute class extends PyAttribute); however, intermediate 828 /// python-visible base classes can be modeled by specifying a BaseTy. 
829 template <typename DerivedTy, typename BaseTy = PyAttribute> 830 class PyConcreteAttribute : public BaseTy { 831 public: 832 // Derived classes must define statics for: 833 // IsAFunctionTy isaFunction 834 // const char *pyClassName 835 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 836 using IsAFunctionTy = bool (*)(MlirAttribute); 837 838 PyConcreteAttribute() = default; 839 PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 840 : BaseTy(std::move(contextRef), attr) {} 841 PyConcreteAttribute(PyAttribute &orig) 842 : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} 843 844 static MlirAttribute castFrom(PyAttribute &orig) { 845 if (!DerivedTy::isaFunction(orig)) { 846 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 847 throw SetPyError(PyExc_ValueError, 848 llvm::Twine("Cannot cast attribute to ") + 849 DerivedTy::pyClassName + " (from " + origRepr + ")"); 850 } 851 return orig; 852 } 853 854 static void bind(pybind11::module &m) { 855 auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), 856 pybind11::module_local()); 857 cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(), 858 pybind11::arg("cast_from_attr")); 859 cls.def_static( 860 "isinstance", 861 [](PyAttribute &otherAttr) -> bool { 862 return DerivedTy::isaFunction(otherAttr); 863 }, 864 pybind11::arg("other")); 865 cls.def_property_readonly("type", [](PyAttribute &attr) { 866 return PyType(attr.getContext(), mlirAttributeGetType(attr)); 867 }); 868 DerivedTy::bindDerived(cls); 869 } 870 871 /// Implemented by derived classes to add methods to the Python subclass. 872 static void bindDerived(ClassTy &m) {} 873 }; 874 875 /// Wrapper around the generic MlirValue. 876 /// Values are managed completely by the operation that resulted in their 877 /// definition. For op result value, this is the operation that defines the 878 /// value. 
For block argument values, this is the operation that contains the 879 /// block to which the value is an argument (blocks cannot be detached in Python 880 /// bindings so such operation always exists). 881 class PyValue { 882 public: 883 PyValue(PyOperationRef parentOperation, MlirValue value) 884 : parentOperation(std::move(parentOperation)), value(value) {} 885 operator MlirValue() const { return value; } 886 887 MlirValue get() { return value; } 888 PyOperationRef &getParentOperation() { return parentOperation; } 889 890 void checkValid() { return parentOperation->checkValid(); } 891 892 /// Gets a capsule wrapping the void* within the MlirValue. 893 pybind11::object getCapsule(); 894 895 /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of 896 /// the underlying MlirValue is still tied to the owning operation. 897 static PyValue createFromCapsule(pybind11::object capsule); 898 899 private: 900 PyOperationRef parentOperation; 901 MlirValue value; 902 }; 903 904 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. 905 class PyAffineExpr : public BaseContextObject { 906 public: 907 PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 908 : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} 909 bool operator==(const PyAffineExpr &other); 910 operator MlirAffineExpr() const { return affineExpr; } 911 MlirAffineExpr get() const { return affineExpr; } 912 913 /// Gets a capsule wrapping the void* within the MlirAffineExpr. 914 pybind11::object getCapsule(); 915 916 /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. 917 /// Note that PyAffineExpr instances are uniqued, so the returned object 918 /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr 919 /// is taken by calling this function. 
920 static PyAffineExpr createFromCapsule(pybind11::object capsule); 921 922 PyAffineExpr add(const PyAffineExpr &other) const; 923 PyAffineExpr mul(const PyAffineExpr &other) const; 924 PyAffineExpr floorDiv(const PyAffineExpr &other) const; 925 PyAffineExpr ceilDiv(const PyAffineExpr &other) const; 926 PyAffineExpr mod(const PyAffineExpr &other) const; 927 928 private: 929 MlirAffineExpr affineExpr; 930 }; 931 932 class PyAffineMap : public BaseContextObject { 933 public: 934 PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) 935 : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} 936 bool operator==(const PyAffineMap &other); 937 operator MlirAffineMap() const { return affineMap; } 938 MlirAffineMap get() const { return affineMap; } 939 940 /// Gets a capsule wrapping the void* within the MlirAffineMap. 941 pybind11::object getCapsule(); 942 943 /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. 944 /// Note that PyAffineMap instances are uniqued, so the returned object 945 /// may be a pre-existing object. Ownership of the underlying MlirAffineMap 946 /// is taken by calling this function. 947 static PyAffineMap createFromCapsule(pybind11::object capsule); 948 949 private: 950 MlirAffineMap affineMap; 951 }; 952 953 class PyIntegerSet : public BaseContextObject { 954 public: 955 PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) 956 : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} 957 bool operator==(const PyIntegerSet &other); 958 operator MlirIntegerSet() const { return integerSet; } 959 MlirIntegerSet get() const { return integerSet; } 960 961 /// Gets a capsule wrapping the void* within the MlirIntegerSet. 962 pybind11::object getCapsule(); 963 964 /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. 965 /// Note that PyIntegerSet instances may be uniqued, so the returned object 966 /// may be a pre-existing object. 
Integer sets are owned by the context. 967 static PyIntegerSet createFromCapsule(pybind11::object capsule); 968 969 private: 970 MlirIntegerSet integerSet; 971 }; 972 973 /// Bindings for MLIR symbol tables. 974 class PySymbolTable { 975 public: 976 /// Constructs a symbol table for the given operation. 977 explicit PySymbolTable(PyOperationBase &operation); 978 979 /// Destroys the symbol table. 980 ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); } 981 982 /// Returns the symbol (opview) with the given name, throws if there is no 983 /// such symbol in the table. 984 pybind11::object dunderGetItem(const std::string &name); 985 986 /// Removes the given operation from the symbol table and erases it. 987 void erase(PyOperationBase &symbol); 988 989 /// Removes the operation with the given name from the symbol table and erases 990 /// it, throws if there is no such symbol in the table. 991 void dunderDel(const std::string &name); 992 993 /// Inserts the given operation into the symbol table. The operation must have 994 /// the symbol trait. 995 PyAttribute insert(PyOperationBase &symbol); 996 997 /// Gets and sets the name of a symbol op. 998 static PyAttribute getSymbolName(PyOperationBase &symbol); 999 static void setSymbolName(PyOperationBase &symbol, const std::string &name); 1000 1001 /// Gets and sets the visibility of a symbol op. 1002 static PyAttribute getVisibility(PyOperationBase &symbol); 1003 static void setVisibility(PyOperationBase &symbol, 1004 const std::string &visibility); 1005 1006 /// Replaces all symbol uses within an operation. See the API 1007 /// mlirSymbolTableReplaceAllSymbolUses for all caveats. 1008 static void replaceAllSymbolUses(const std::string &oldSymbol, 1009 const std::string &newSymbol, 1010 PyOperationBase &from); 1011 1012 /// Walks all symbol tables under and including 'from'. 
1013 static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, 1014 pybind11::object callback); 1015 1016 /// Casts the bindings class into the C API structure. 1017 operator MlirSymbolTable() { return symbolTable; } 1018 1019 private: 1020 PyOperationRef operation; 1021 MlirSymbolTable symbolTable; 1022 }; 1023 1024 void populateIRAffine(pybind11::module &m); 1025 void populateIRAttributes(pybind11::module &m); 1026 void populateIRCore(pybind11::module &m); 1027 void populateIRInterfaces(pybind11::module &m); 1028 void populateIRTypes(pybind11::module &m); 1029 1030 } // namespace python 1031 } // namespace mlir 1032 1033 namespace pybind11 { 1034 namespace detail { 1035 1036 template <> 1037 struct type_caster<mlir::python::DefaultingPyMlirContext> 1038 : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {}; 1039 template <> 1040 struct type_caster<mlir::python::DefaultingPyLocation> 1041 : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {}; 1042 1043 } // namespace detail 1044 } // namespace pybind11 1045 1046 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H 1047