1 //===- IRModules.h - IR Submodules of pybind module -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H 10 #define MLIR_BINDINGS_PYTHON_IRMODULES_H 11 12 #include <vector> 13 14 #include "PybindUtils.h" 15 16 #include "mlir-c/AffineExpr.h" 17 #include "mlir-c/AffineMap.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/IntegerSet.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/ADT/Optional.h" 22 23 namespace mlir { 24 namespace python { 25 26 class PyBlock; 27 class PyInsertionPoint; 28 class PyLocation; 29 class DefaultingPyLocation; 30 class PyMlirContext; 31 class DefaultingPyMlirContext; 32 class PyModule; 33 class PyOperation; 34 class PyType; 35 class PySymbolTable; 36 class PyValue; 37 38 /// Template for a reference to a concrete type which captures a python 39 /// reference to its underlying python object. 40 template <typename T> 41 class PyObjectRef { 42 public: 43 PyObjectRef(T *referrent, pybind11::object object) 44 : referrent(referrent), object(std::move(object)) { 45 assert(this->referrent && 46 "cannot construct PyObjectRef with null referrent"); 47 assert(this->object && "cannot construct PyObjectRef with null object"); 48 } 49 PyObjectRef(PyObjectRef &&other) 50 : referrent(other.referrent), object(std::move(other.object)) { 51 other.referrent = nullptr; 52 assert(!other.object); 53 } 54 PyObjectRef(const PyObjectRef &other) 55 : referrent(other.referrent), object(other.object /* copies */) {} 56 ~PyObjectRef() {} 57 58 int getRefCount() { 59 if (!object) 60 return 0; 61 return object.ref_count(); 62 } 63 64 /// Releases the object held by this instance, returning it. 65 /// This is the proper thing to return from a function that wants to return 66 /// the reference. Note that this does not work from initializers. 67 pybind11::object releaseObject() { 68 assert(referrent && object); 69 referrent = nullptr; 70 auto stolen = std::move(object); 71 return stolen; 72 } 73 74 T *get() { return referrent; } 75 T *operator->() { 76 assert(referrent && object); 77 return referrent; 78 } 79 pybind11::object getObject() { 80 assert(referrent && object); 81 return object; 82 } 83 operator bool() const { return referrent && object; } 84 85 private: 86 T *referrent; 87 pybind11::object object; 88 }; 89 90 /// Tracks an entry in the thread context stack. New entries are pushed onto 91 /// here for each with block that activates a new InsertionPoint, Context or 92 /// Location. 93 /// 94 /// Pushing either a Location or InsertionPoint also pushes its associated 95 /// Context. Pushing a Context will not modify the Location or InsertionPoint 96 /// unless if they are from a different context, in which case, they are 97 /// cleared. 98 class PyThreadContextEntry { 99 public: 100 enum class FrameKind { 101 Context, 102 InsertionPoint, 103 Location, 104 }; 105 106 PyThreadContextEntry(FrameKind frameKind, pybind11::object context, 107 pybind11::object insertionPoint, 108 pybind11::object location) 109 : context(std::move(context)), insertionPoint(std::move(insertionPoint)), 110 location(std::move(location)), frameKind(frameKind) {} 111 112 /// Gets the top of stack context and return nullptr if not defined. 113 static PyMlirContext *getDefaultContext(); 114 115 /// Gets the top of stack insertion point and return nullptr if not defined. 116 static PyInsertionPoint *getDefaultInsertionPoint(); 117 118 /// Gets the top of stack location and returns nullptr if not defined. 119 static PyLocation *getDefaultLocation(); 120 121 PyMlirContext *getContext(); 122 PyInsertionPoint *getInsertionPoint(); 123 PyLocation *getLocation(); 124 FrameKind getFrameKind() { return frameKind; } 125 126 /// Stack management. 127 static PyThreadContextEntry *getTopOfStack(); 128 static pybind11::object pushContext(PyMlirContext &context); 129 static void popContext(PyMlirContext &context); 130 static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); 131 static void popInsertionPoint(PyInsertionPoint &insertionPoint); 132 static pybind11::object pushLocation(PyLocation &location); 133 static void popLocation(PyLocation &location); 134 135 /// Gets the thread local stack. 136 static std::vector<PyThreadContextEntry> &getStack(); 137 138 private: 139 static void push(FrameKind frameKind, pybind11::object context, 140 pybind11::object insertionPoint, pybind11::object location); 141 142 /// An object reference to the PyContext. 143 pybind11::object context; 144 /// An object reference to the current insertion point. 145 pybind11::object insertionPoint; 146 /// An object reference to the current location. 147 pybind11::object location; 148 // The kind of push that was performed. 149 FrameKind frameKind; 150 }; 151 152 /// Wrapper around MlirContext. 153 using PyMlirContextRef = PyObjectRef<PyMlirContext>; 154 class PyMlirContext { 155 public: 156 PyMlirContext() = delete; 157 PyMlirContext(const PyMlirContext &) = delete; 158 PyMlirContext(PyMlirContext &&) = delete; 159 160 /// For the case of a python __init__ (py::init) method, pybind11 is quite 161 /// strict about needing to return a pointer that is not yet associated to 162 /// an py::object. Since the forContext() method acts like a pool, possibly 163 /// returning a recycled context, it does not satisfy this need. The usual 164 /// way in python to accomplish such a thing is to override __new__, but 165 /// that is also not supported by pybind11. Instead, we use this entry 166 /// point which always constructs a fresh context (which cannot alias an 167 /// existing one because it is fresh). 168 static PyMlirContext *createNewContextForInit(); 169 170 /// Returns a context reference for the singleton PyMlirContext wrapper for 171 /// the given context. 172 static PyMlirContextRef forContext(MlirContext context); 173 ~PyMlirContext(); 174 175 /// Accesses the underlying MlirContext. 176 MlirContext get() { return context; } 177 178 /// Gets a strong reference to this context, which will ensure it is kept 179 /// alive for the life of the reference. 180 PyMlirContextRef getRef() { 181 return PyMlirContextRef(this, pybind11::cast(this)); 182 } 183 184 /// Gets a capsule wrapping the void* within the MlirContext. 185 pybind11::object getCapsule(); 186 187 /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. 188 /// Note that PyMlirContext instances are uniqued, so the returned object 189 /// may be a pre-existing object. Ownership of the underlying MlirContext 190 /// is taken by calling this function. 191 static pybind11::object createFromCapsule(pybind11::object capsule); 192 193 /// Gets the count of live context objects. Used for testing. 194 static size_t getLiveCount(); 195 196 /// Gets the count of live operations associated with this context. 197 /// Used for testing. 198 size_t getLiveOperationCount(); 199 200 /// Gets the count of live modules associated with this context. 201 /// Used for testing. 202 size_t getLiveModuleCount(); 203 204 /// Enter and exit the context manager. 205 pybind11::object contextEnter(); 206 void contextExit(const pybind11::object &excType, 207 const pybind11::object &excVal, 208 const pybind11::object &excTb); 209 210 private: 211 PyMlirContext(MlirContext context); 212 // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, 213 // preserving the relationship that an MlirContext maps to a single 214 // PyMlirContext wrapper. This could be replaced in the future with an 215 // extension mechanism on the MlirContext for stashing user pointers. 216 // Note that this holds a handle, which does not imply ownership. 217 // Mappings will be removed when the context is destructed. 218 using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>; 219 static LiveContextMap &getLiveContexts(); 220 221 // Interns all live modules associated with this context. Modules tracked 222 // in this map are valid. When a module is invalidated, it is removed 223 // from this map, and while it still exists as an instance, any 224 // attempt to access it will raise an error. 225 using LiveModuleMap = 226 llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>; 227 LiveModuleMap liveModules; 228 229 // Interns all live operations associated with this context. Operations 230 // tracked in this map are valid. When an operation is invalidated, it is 231 // removed from this map, and while it still exists as an instance, any 232 // attempt to access it will raise an error. 233 using LiveOperationMap = 234 llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>; 235 LiveOperationMap liveOperations; 236 237 MlirContext context; 238 friend class PyModule; 239 friend class PyOperation; 240 }; 241 242 /// Used in function arguments when None should resolve to the current context 243 /// manager set instance. 244 class DefaultingPyMlirContext 245 : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { 246 public: 247 using Defaulting::Defaulting; 248 static constexpr const char kTypeDescription[] = "mlir.ir.Context"; 249 static PyMlirContext &resolve(); 250 }; 251 252 /// Base class for all objects that directly or indirectly depend on an 253 /// MlirContext. The lifetime of the context will extend at least to the 254 /// lifetime of these instances. 255 /// Immutable objects that depend on a context extend this directly. 256 class BaseContextObject { 257 public: 258 BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { 259 assert(this->contextRef && 260 "context object constructed with null context ref"); 261 } 262 263 /// Accesses the context reference. 264 PyMlirContextRef &getContext() { return contextRef; } 265 266 private: 267 PyMlirContextRef contextRef; 268 }; 269 270 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in 271 /// order to differentiate it from the `Dialect` base class which is extended by 272 /// plugins which extend dialect functionality through extension python code. 273 /// This should be seen as the "low-level" object and `Dialect` as the 274 /// high-level, user facing object. 275 class PyDialectDescriptor : public BaseContextObject { 276 public: 277 PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) 278 : BaseContextObject(std::move(contextRef)), dialect(dialect) {} 279 280 MlirDialect get() { return dialect; } 281 282 private: 283 MlirDialect dialect; 284 }; 285 286 /// User-level object for accessing dialects with dotted syntax such as: 287 /// ctx.dialect.std 288 class PyDialects : public BaseContextObject { 289 public: 290 PyDialects(PyMlirContextRef contextRef) 291 : BaseContextObject(std::move(contextRef)) {} 292 293 MlirDialect getDialectForKey(const std::string &key, bool attrError); 294 }; 295 296 /// User-level dialect object. For dialects that have a registered extension, 297 /// this will be the base class of the extension dialect type. For un-extended, 298 /// objects of this type will be returned directly. 299 class PyDialect { 300 public: 301 PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} 302 303 pybind11::object getDescriptor() { return descriptor; } 304 305 private: 306 pybind11::object descriptor; 307 }; 308 309 /// Wrapper around an MlirLocation. 310 class PyLocation : public BaseContextObject { 311 public: 312 PyLocation(PyMlirContextRef contextRef, MlirLocation loc) 313 : BaseContextObject(std::move(contextRef)), loc(loc) {} 314 315 operator MlirLocation() const { return loc; } 316 MlirLocation get() const { return loc; } 317 318 /// Enter and exit the context manager. 319 pybind11::object contextEnter(); 320 void contextExit(const pybind11::object &excType, 321 const pybind11::object &excVal, 322 const pybind11::object &excTb); 323 324 /// Gets a capsule wrapping the void* within the MlirLocation. 325 pybind11::object getCapsule(); 326 327 /// Creates a PyLocation from the MlirLocation wrapped by a capsule. 328 /// Note that PyLocation instances are uniqued, so the returned object 329 /// may be a pre-existing object. Ownership of the underlying MlirLocation 330 /// is taken by calling this function. 331 static PyLocation createFromCapsule(pybind11::object capsule); 332 333 private: 334 MlirLocation loc; 335 }; 336 337 /// Used in function arguments when None should resolve to the current context 338 /// manager set instance. 339 class DefaultingPyLocation 340 : public Defaulting<DefaultingPyLocation, PyLocation> { 341 public: 342 using Defaulting::Defaulting; 343 static constexpr const char kTypeDescription[] = "mlir.ir.Location"; 344 static PyLocation &resolve(); 345 346 operator MlirLocation() const { return *get(); } 347 }; 348 349 /// Wrapper around MlirModule. 350 /// This is the top-level, user-owned object that contains regions/ops/blocks. 351 class PyModule; 352 using PyModuleRef = PyObjectRef<PyModule>; 353 class PyModule : public BaseContextObject { 354 public: 355 /// Returns a PyModule reference for the given MlirModule. This may return 356 /// a pre-existing or new object. 357 static PyModuleRef forModule(MlirModule module); 358 PyModule(PyModule &) = delete; 359 PyModule(PyMlirContext &&) = delete; 360 ~PyModule(); 361 362 /// Gets the backing MlirModule. 363 MlirModule get() { return module; } 364 365 /// Gets a strong reference to this module. 366 PyModuleRef getRef() { 367 return PyModuleRef(this, 368 pybind11::reinterpret_borrow<pybind11::object>(handle)); 369 } 370 371 /// Gets a capsule wrapping the void* within the MlirModule. 372 /// Note that the module does not (yet) provide a corresponding factory for 373 /// constructing from a capsule as that would require uniquing PyModule 374 /// instances, which is not currently done. 375 pybind11::object getCapsule(); 376 377 /// Creates a PyModule from the MlirModule wrapped by a capsule. 378 /// Note that PyModule instances are uniqued, so the returned object 379 /// may be a pre-existing object. Ownership of the underlying MlirModule 380 /// is taken by calling this function. 381 static pybind11::object createFromCapsule(pybind11::object capsule); 382 383 private: 384 PyModule(PyMlirContextRef contextRef, MlirModule module); 385 MlirModule module; 386 pybind11::handle handle; 387 }; 388 389 /// Base class for PyOperation and PyOpView which exposes the primary, user 390 /// visible methods for manipulating it. 391 class PyOperationBase { 392 public: 393 virtual ~PyOperationBase() = default; 394 /// Implements the bound 'print' method and helps with others. 395 void print(pybind11::object fileObject, bool binary, 396 llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo, 397 bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, 398 bool assumeVerified); 399 pybind11::object getAsm(bool binary, 400 llvm::Optional<int64_t> largeElementsLimit, 401 bool enableDebugInfo, bool prettyDebugInfo, 402 bool printGenericOpForm, bool useLocalScope, 403 bool assumeVerified); 404 405 /// Moves the operation before or after the other operation. 406 void moveAfter(PyOperationBase &other); 407 void moveBefore(PyOperationBase &other); 408 409 /// Each must provide access to the raw Operation. 410 virtual PyOperation &getOperation() = 0; 411 }; 412 413 /// Wrapper around PyOperation. 414 /// Operations exist in either an attached (dependent) or detached (top-level) 415 /// state. In the detached state (as on creation), an operation is owned by 416 /// the creator and its lifetime extends either until its reference count 417 /// drops to zero or it is attached to a parent, at which point its lifetime 418 /// is bounded by its top-level parent reference. 419 class PyOperation; 420 using PyOperationRef = PyObjectRef<PyOperation>; 421 class PyOperation : public PyOperationBase, public BaseContextObject { 422 public: 423 ~PyOperation(); 424 PyOperation &getOperation() override { return *this; } 425 426 /// Returns a PyOperation for the given MlirOperation, optionally associating 427 /// it with a parentKeepAlive. 428 static PyOperationRef 429 forOperation(PyMlirContextRef contextRef, MlirOperation operation, 430 pybind11::object parentKeepAlive = pybind11::object()); 431 432 /// Creates a detached operation. The operation must not be associated with 433 /// any existing live operation. 434 static PyOperationRef 435 createDetached(PyMlirContextRef contextRef, MlirOperation operation, 436 pybind11::object parentKeepAlive = pybind11::object()); 437 438 /// Detaches the operation from its parent block and updates its state 439 /// accordingly. 440 void detachFromParent() { 441 mlirOperationRemoveFromParent(getOperation()); 442 setDetached(); 443 parentKeepAlive = pybind11::object(); 444 } 445 446 /// Gets the backing operation. 447 operator MlirOperation() const { return get(); } 448 MlirOperation get() const { 449 checkValid(); 450 return operation; 451 } 452 453 PyOperationRef getRef() { 454 return PyOperationRef( 455 this, pybind11::reinterpret_borrow<pybind11::object>(handle)); 456 } 457 458 bool isAttached() { return attached; } 459 void setAttached(pybind11::object parent = pybind11::object()) { 460 assert(!attached && "operation already attached"); 461 attached = true; 462 } 463 void setDetached() { 464 assert(attached && "operation already detached"); 465 attached = false; 466 } 467 void checkValid() const; 468 469 /// Gets the owning block or raises an exception if the operation has no 470 /// owning block. 471 PyBlock getBlock(); 472 473 /// Gets the parent operation or raises an exception if the operation has 474 /// no parent. 475 llvm::Optional<PyOperationRef> getParentOperation(); 476 477 /// Gets a capsule wrapping the void* within the MlirOperation. 478 pybind11::object getCapsule(); 479 480 /// Creates a PyOperation from the MlirOperation wrapped by a capsule. 481 /// Ownership of the underlying MlirOperation is taken by calling this 482 /// function. 483 static pybind11::object createFromCapsule(pybind11::object capsule); 484 485 /// Creates an operation. See corresponding python docstring. 486 static pybind11::object 487 create(const std::string &name, llvm::Optional<std::vector<PyType *>> results, 488 llvm::Optional<std::vector<PyValue *>> operands, 489 llvm::Optional<pybind11::dict> attributes, 490 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 491 DefaultingPyLocation location, const pybind11::object &ip); 492 493 /// Creates an OpView suitable for this operation. 494 pybind11::object createOpView(); 495 496 /// Erases the underlying MlirOperation, removes its pointer from the 497 /// parent context's live operations map, and sets the valid bit false. 498 void erase(); 499 500 private: 501 PyOperation(PyMlirContextRef contextRef, MlirOperation operation); 502 static PyOperationRef createInstance(PyMlirContextRef contextRef, 503 MlirOperation operation, 504 pybind11::object parentKeepAlive); 505 506 MlirOperation operation; 507 pybind11::handle handle; 508 // Keeps the parent alive, regardless of whether it is an Operation or 509 // Module. 510 // TODO: As implemented, this facility is only sufficient for modeling the 511 // trivial module parent back-reference. Generalize this to also account for 512 // transitions from detached to attached and address TODOs in the 513 // ir_operation.py regarding testing corresponding lifetime guarantees. 514 pybind11::object parentKeepAlive; 515 bool attached = true; 516 bool valid = true; 517 518 friend class PyOperationBase; 519 friend class PySymbolTable; 520 }; 521 522 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for 523 /// providing more instance-specific accessors and serve as the base class for 524 /// custom ODS-style operation classes. Since this class is subclass on the 525 /// python side, it must present an __init__ method that operates in pure 526 /// python types. 527 class PyOpView : public PyOperationBase { 528 public: 529 PyOpView(const pybind11::object &operationObject); 530 PyOperation &getOperation() override { return operation; } 531 532 static pybind11::object createRawSubclass(const pybind11::object &userClass); 533 534 pybind11::object getOperationObject() { return operationObject; } 535 536 static pybind11::object 537 buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList, 538 pybind11::list operandList, 539 llvm::Optional<pybind11::dict> attributes, 540 llvm::Optional<std::vector<PyBlock *>> successors, 541 llvm::Optional<int> regions, DefaultingPyLocation location, 542 const pybind11::object &maybeIp); 543 544 private: 545 PyOperation &operation; // For efficient, cast-free access from C++ 546 pybind11::object operationObject; // Holds the reference. 547 }; 548 549 /// Wrapper around an MlirRegion. 550 /// Regions are managed completely by their containing operation. Unlike the 551 /// C++ API, the python API does not support detached regions. 552 class PyRegion { 553 public: 554 PyRegion(PyOperationRef parentOperation, MlirRegion region) 555 : parentOperation(std::move(parentOperation)), region(region) { 556 assert(!mlirRegionIsNull(region) && "python region cannot be null"); 557 } 558 operator MlirRegion() const { return region; } 559 560 MlirRegion get() { return region; } 561 PyOperationRef &getParentOperation() { return parentOperation; } 562 563 void checkValid() { return parentOperation->checkValid(); } 564 565 private: 566 PyOperationRef parentOperation; 567 MlirRegion region; 568 }; 569 570 /// Wrapper around an MlirBlock. 571 /// Blocks are managed completely by their containing operation. Unlike the 572 /// C++ API, the python API does not support detached blocks. 573 class PyBlock { 574 public: 575 PyBlock(PyOperationRef parentOperation, MlirBlock block) 576 : parentOperation(std::move(parentOperation)), block(block) { 577 assert(!mlirBlockIsNull(block) && "python block cannot be null"); 578 } 579 580 MlirBlock get() { return block; } 581 PyOperationRef &getParentOperation() { return parentOperation; } 582 583 void checkValid() { return parentOperation->checkValid(); } 584 585 private: 586 PyOperationRef parentOperation; 587 MlirBlock block; 588 }; 589 590 /// An insertion point maintains a pointer to a Block and a reference operation. 591 /// Calls to insert() will insert a new operation before the 592 /// reference operation. If the reference operation is null, then appends to 593 /// the end of the block. 594 class PyInsertionPoint { 595 public: 596 /// Creates an insertion point positioned after the last operation in the 597 /// block, but still inside the block. 598 PyInsertionPoint(PyBlock &block); 599 /// Creates an insertion point positioned before a reference operation. 600 PyInsertionPoint(PyOperationBase &beforeOperationBase); 601 602 /// Shortcut to create an insertion point at the beginning of the block. 603 static PyInsertionPoint atBlockBegin(PyBlock &block); 604 /// Shortcut to create an insertion point before the block terminator. 605 static PyInsertionPoint atBlockTerminator(PyBlock &block); 606 607 /// Inserts an operation. 608 void insert(PyOperationBase &operationBase); 609 610 /// Enter and exit the context manager. 611 pybind11::object contextEnter(); 612 void contextExit(const pybind11::object &excType, 613 const pybind11::object &excVal, 614 const pybind11::object &excTb); 615 616 PyBlock &getBlock() { return block; } 617 618 private: 619 // Trampoline constructor that avoids null initializing members while 620 // looking up parents. 621 PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation) 622 : refOperation(std::move(refOperation)), block(std::move(block)) {} 623 624 llvm::Optional<PyOperationRef> refOperation; 625 PyBlock block; 626 }; 627 /// Wrapper around the generic MlirType. 628 /// The lifetime of a type is bound by the PyContext that created it. 629 class PyType : public BaseContextObject { 630 public: 631 PyType(PyMlirContextRef contextRef, MlirType type) 632 : BaseContextObject(std::move(contextRef)), type(type) {} 633 bool operator==(const PyType &other); 634 operator MlirType() const { return type; } 635 MlirType get() const { return type; } 636 637 /// Gets a capsule wrapping the void* within the MlirType. 638 pybind11::object getCapsule(); 639 640 /// Creates a PyType from the MlirType wrapped by a capsule. 641 /// Note that PyType instances are uniqued, so the returned object 642 /// may be a pre-existing object. Ownership of the underlying MlirType 643 /// is taken by calling this function. 644 static PyType createFromCapsule(pybind11::object capsule); 645 646 private: 647 MlirType type; 648 }; 649 650 /// CRTP base classes for Python types that subclass Type and should be 651 /// castable from it (i.e. via something like IntegerType(t)). 652 /// By default, type class hierarchies are one level deep (i.e. a 653 /// concrete type class extends PyType); however, intermediate python-visible 654 /// base classes can be modeled by specifying a BaseTy. 655 template <typename DerivedTy, typename BaseTy = PyType> 656 class PyConcreteType : public BaseTy { 657 public: 658 // Derived classes must define statics for: 659 // IsAFunctionTy isaFunction 660 // const char *pyClassName 661 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 662 using IsAFunctionTy = bool (*)(MlirType); 663 664 PyConcreteType() = default; 665 PyConcreteType(PyMlirContextRef contextRef, MlirType t) 666 : BaseTy(std::move(contextRef), t) {} 667 PyConcreteType(PyType &orig) 668 : PyConcreteType(orig.getContext(), castFrom(orig)) {} 669 670 static MlirType castFrom(PyType &orig) { 671 if (!DerivedTy::isaFunction(orig)) { 672 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 673 throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + 674 DerivedTy::pyClassName + 675 " (from " + origRepr + ")"); 676 } 677 return orig; 678 } 679 680 static void bind(pybind11::module &m) { 681 auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); 682 cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(), 683 pybind11::arg("cast_from_type")); 684 cls.def_static( 685 "isinstance", 686 [](PyType &otherType) -> bool { 687 return DerivedTy::isaFunction(otherType); 688 }, 689 pybind11::arg("other")); 690 DerivedTy::bindDerived(cls); 691 } 692 693 /// Implemented by derived classes to add methods to the Python subclass. 694 static void bindDerived(ClassTy &m) {} 695 }; 696 697 /// Wrapper around the generic MlirAttribute. 698 /// The lifetime of a type is bound by the PyContext that created it. 699 class PyAttribute : public BaseContextObject { 700 public: 701 PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 702 : BaseContextObject(std::move(contextRef)), attr(attr) {} 703 bool operator==(const PyAttribute &other); 704 operator MlirAttribute() const { return attr; } 705 MlirAttribute get() const { return attr; } 706 707 /// Gets a capsule wrapping the void* within the MlirAttribute. 708 pybind11::object getCapsule(); 709 710 /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. 711 /// Note that PyAttribute instances are uniqued, so the returned object 712 /// may be a pre-existing object. Ownership of the underlying MlirAttribute 713 /// is taken by calling this function. 714 static PyAttribute createFromCapsule(pybind11::object capsule); 715 716 private: 717 MlirAttribute attr; 718 }; 719 720 /// Represents a Python MlirNamedAttr, carrying an optional owned name. 721 /// TODO: Refactor this and the C-API to be based on an Identifier owned 722 /// by the context so as to avoid ownership issues here. 723 class PyNamedAttribute { 724 public: 725 /// Constructs a PyNamedAttr that retains an owned name. This should be 726 /// used in any code that originates an MlirNamedAttribute from a python 727 /// string. 728 /// The lifetime of the PyNamedAttr must extend to the lifetime of the 729 /// passed attribute. 730 PyNamedAttribute(MlirAttribute attr, std::string ownedName); 731 732 MlirNamedAttribute namedAttr; 733 734 private: 735 // Since the MlirNamedAttr contains an internal pointer to the actual 736 // memory of the owned string, it must be heap allocated to remain valid. 737 // Otherwise, strings that fit within the small object optimization threshold 738 // will have their memory address change as the containing object is moved, 739 // resulting in an invalid aliased pointer. 740 std::unique_ptr<std::string> ownedName; 741 }; 742 743 /// CRTP base classes for Python attributes that subclass Attribute and should 744 /// be castable from it (i.e. via something like StringAttr(attr)). 745 /// By default, attribute class hierarchies are one level deep (i.e. a 746 /// concrete attribute class extends PyAttribute); however, intermediate 747 /// python-visible base classes can be modeled by specifying a BaseTy. 748 template <typename DerivedTy, typename BaseTy = PyAttribute> 749 class PyConcreteAttribute : public BaseTy { 750 public: 751 // Derived classes must define statics for: 752 // IsAFunctionTy isaFunction 753 // const char *pyClassName 754 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 755 using IsAFunctionTy = bool (*)(MlirAttribute); 756 757 PyConcreteAttribute() = default; 758 PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 759 : BaseTy(std::move(contextRef), attr) {} 760 PyConcreteAttribute(PyAttribute &orig) 761 : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} 762 763 static MlirAttribute castFrom(PyAttribute &orig) { 764 if (!DerivedTy::isaFunction(orig)) { 765 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 766 throw SetPyError(PyExc_ValueError, 767 llvm::Twine("Cannot cast attribute to ") + 768 DerivedTy::pyClassName + " (from " + origRepr + ")"); 769 } 770 return orig; 771 } 772 773 static void bind(pybind11::module &m) { 774 auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), 775 pybind11::module_local()); 776 cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(), 777 pybind11::arg("cast_from_attr")); 778 cls.def_static( 779 "isinstance", 780 [](PyAttribute &otherAttr) -> bool { 781 return DerivedTy::isaFunction(otherAttr); 782 }, 783 pybind11::arg("other")); 784 cls.def_property_readonly("type", [](PyAttribute &attr) { 785 return PyType(attr.getContext(), mlirAttributeGetType(attr)); 786 }); 787 DerivedTy::bindDerived(cls); 788 } 789 790 /// Implemented by derived classes to add methods to the Python subclass. 791 static void bindDerived(ClassTy &m) {} 792 }; 793 794 /// Wrapper around the generic MlirValue. 795 /// Values are managed completely by the operation that resulted in their 796 /// definition. For op result value, this is the operation that defines the 797 /// value. For block argument values, this is the operation that contains the 798 /// block to which the value is an argument (blocks cannot be detached in Python 799 /// bindings so such operation always exists). 800 class PyValue { 801 public: 802 PyValue(PyOperationRef parentOperation, MlirValue value) 803 : parentOperation(parentOperation), value(value) {} 804 operator MlirValue() const { return value; } 805 806 MlirValue get() { return value; } 807 PyOperationRef &getParentOperation() { return parentOperation; } 808 809 void checkValid() { return parentOperation->checkValid(); } 810 811 /// Gets a capsule wrapping the void* within the MlirValue. 812 pybind11::object getCapsule(); 813 814 /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of 815 /// the underlying MlirValue is still tied to the owning operation. 816 static PyValue createFromCapsule(pybind11::object capsule); 817 818 private: 819 PyOperationRef parentOperation; 820 MlirValue value; 821 }; 822 823 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. 824 class PyAffineExpr : public BaseContextObject { 825 public: 826 PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 827 : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} 828 bool operator==(const PyAffineExpr &other); 829 operator MlirAffineExpr() const { return affineExpr; } 830 MlirAffineExpr get() const { return affineExpr; } 831 832 /// Gets a capsule wrapping the void* within the MlirAffineExpr. 833 pybind11::object getCapsule(); 834 835 /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. 836 /// Note that PyAffineExpr instances are uniqued, so the returned object 837 /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr 838 /// is taken by calling this function. 839 static PyAffineExpr createFromCapsule(pybind11::object capsule); 840 841 PyAffineExpr add(const PyAffineExpr &other) const; 842 PyAffineExpr mul(const PyAffineExpr &other) const; 843 PyAffineExpr floorDiv(const PyAffineExpr &other) const; 844 PyAffineExpr ceilDiv(const PyAffineExpr &other) const; 845 PyAffineExpr mod(const PyAffineExpr &other) const; 846 847 private: 848 MlirAffineExpr affineExpr; 849 }; 850 851 class PyAffineMap : public BaseContextObject { 852 public: 853 PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) 854 : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} 855 bool operator==(const PyAffineMap &other); 856 operator MlirAffineMap() const { return affineMap; } 857 MlirAffineMap get() const { return affineMap; } 858 859 /// Gets a capsule wrapping the void* within the MlirAffineMap. 860 pybind11::object getCapsule(); 861 862 /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. 863 /// Note that PyAffineMap instances are uniqued, so the returned object 864 /// may be a pre-existing object. Ownership of the underlying MlirAffineMap 865 /// is taken by calling this function. 866 static PyAffineMap createFromCapsule(pybind11::object capsule); 867 868 private: 869 MlirAffineMap affineMap; 870 }; 871 872 class PyIntegerSet : public BaseContextObject { 873 public: 874 PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) 875 : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} 876 bool operator==(const PyIntegerSet &other); 877 operator MlirIntegerSet() const { return integerSet; } 878 MlirIntegerSet get() const { return integerSet; } 879 880 /// Gets a capsule wrapping the void* within the MlirIntegerSet. 881 pybind11::object getCapsule(); 882 883 /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. 884 /// Note that PyIntegerSet instances may be uniqued, so the returned object 885 /// may be a pre-existing object. Integer sets are owned by the context. 886 static PyIntegerSet createFromCapsule(pybind11::object capsule); 887 888 private: 889 MlirIntegerSet integerSet; 890 }; 891 892 /// Bindings for MLIR symbol tables. 893 class PySymbolTable { 894 public: 895 /// Constructs a symbol table for the given operation. 896 explicit PySymbolTable(PyOperationBase &operation); 897 898 /// Destroys the symbol table. 899 ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); } 900 901 /// Returns the symbol (opview) with the given name, throws if there is no 902 /// such symbol in the table. 903 pybind11::object dunderGetItem(const std::string &name); 904 905 /// Removes the given operation from the symbol table and erases it. 906 void erase(PyOperationBase &symbol); 907 908 /// Removes the operation with the given name from the symbol table and erases 909 /// it, throws if there is no such symbol in the table. 910 void dunderDel(const std::string &name); 911 912 /// Inserts the given operation into the symbol table. The operation must have 913 /// the symbol trait. 914 PyAttribute insert(PyOperationBase &symbol); 915 916 /// Gets and sets the name of a symbol op. 917 static PyAttribute getSymbolName(PyOperationBase &symbol); 918 static void setSymbolName(PyOperationBase &symbol, const std::string &name); 919 920 /// Gets and sets the visibility of a symbol op. 921 static PyAttribute getVisibility(PyOperationBase &symbol); 922 static void setVisibility(PyOperationBase &symbol, 923 const std::string &visibility); 924 925 /// Replaces all symbol uses within an operation. See the API 926 /// mlirSymbolTableReplaceAllSymbolUses for all caveats. 927 static void replaceAllSymbolUses(const std::string &oldSymbol, 928 const std::string &newSymbol, 929 PyOperationBase &from); 930 931 /// Walks all symbol tables under and including 'from'. 932 static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, 933 pybind11::object callback); 934 935 /// Casts the bindings class into the C API structure. 936 operator MlirSymbolTable() { return symbolTable; } 937 938 private: 939 PyOperationRef operation; 940 MlirSymbolTable symbolTable; 941 }; 942 943 void populateIRAffine(pybind11::module &m); 944 void populateIRAttributes(pybind11::module &m); 945 void populateIRCore(pybind11::module &m); 946 void populateIRInterfaces(pybind11::module &m); 947 void populateIRTypes(pybind11::module &m); 948 949 } // namespace python 950 } // namespace mlir 951 952 namespace pybind11 { 953 namespace detail { 954 955 template <> 956 struct type_caster<mlir::python::DefaultingPyMlirContext> 957 : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {}; 958 template <> 959 struct type_caster<mlir::python::DefaultingPyLocation> 960 : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {}; 961 962 } // namespace detail 963 } // namespace pybind11 964 965 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H 966