1 //===- IRModules.h - IR Submodules of pybind module -----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H 10 #define MLIR_BINDINGS_PYTHON_IRMODULES_H 11 12 #include <vector> 13 14 #include "PybindUtils.h" 15 16 #include "mlir-c/AffineExpr.h" 17 #include "mlir-c/AffineMap.h" 18 #include "mlir-c/IR.h" 19 #include "mlir-c/IntegerSet.h" 20 #include "llvm/ADT/DenseMap.h" 21 #include "llvm/ADT/Optional.h" 22 23 namespace mlir { 24 namespace python { 25 26 class PyBlock; 27 class PyInsertionPoint; 28 class PyLocation; 29 class DefaultingPyLocation; 30 class PyMlirContext; 31 class DefaultingPyMlirContext; 32 class PyModule; 33 class PyOperation; 34 class PyType; 35 class PyValue; 36 37 /// Template for a reference to a concrete type which captures a python 38 /// reference to its underlying python object. 39 template <typename T> 40 class PyObjectRef { 41 public: 42 PyObjectRef(T *referrent, pybind11::object object) 43 : referrent(referrent), object(std::move(object)) { 44 assert(this->referrent && 45 "cannot construct PyObjectRef with null referrent"); 46 assert(this->object && "cannot construct PyObjectRef with null object"); 47 } 48 PyObjectRef(PyObjectRef &&other) 49 : referrent(other.referrent), object(std::move(other.object)) { 50 other.referrent = nullptr; 51 assert(!other.object); 52 } 53 PyObjectRef(const PyObjectRef &other) 54 : referrent(other.referrent), object(other.object /* copies */) {} 55 ~PyObjectRef() {} 56 57 int getRefCount() { 58 if (!object) 59 return 0; 60 return object.ref_count(); 61 } 62 63 /// Releases the object held by this instance, returning it. 64 /// This is the proper thing to return from a function that wants to return 65 /// the reference. Note that this does not work from initializers. 66 pybind11::object releaseObject() { 67 assert(referrent && object); 68 referrent = nullptr; 69 auto stolen = std::move(object); 70 return stolen; 71 } 72 73 T *get() { return referrent; } 74 T *operator->() { 75 assert(referrent && object); 76 return referrent; 77 } 78 pybind11::object getObject() { 79 assert(referrent && object); 80 return object; 81 } 82 operator bool() const { return referrent && object; } 83 84 private: 85 T *referrent; 86 pybind11::object object; 87 }; 88 89 /// Tracks an entry in the thread context stack. New entries are pushed onto 90 /// here for each with block that activates a new InsertionPoint, Context or 91 /// Location. 92 /// 93 /// Pushing either a Location or InsertionPoint also pushes its associated 94 /// Context. Pushing a Context will not modify the Location or InsertionPoint 95 /// unless if they are from a different context, in which case, they are 96 /// cleared. 97 class PyThreadContextEntry { 98 public: 99 enum class FrameKind { 100 Context, 101 InsertionPoint, 102 Location, 103 }; 104 105 PyThreadContextEntry(FrameKind frameKind, pybind11::object context, 106 pybind11::object insertionPoint, 107 pybind11::object location) 108 : context(std::move(context)), insertionPoint(std::move(insertionPoint)), 109 location(std::move(location)), frameKind(frameKind) {} 110 111 /// Gets the top of stack context and return nullptr if not defined. 112 static PyMlirContext *getDefaultContext(); 113 114 /// Gets the top of stack insertion point and return nullptr if not defined. 115 static PyInsertionPoint *getDefaultInsertionPoint(); 116 117 /// Gets the top of stack location and returns nullptr if not defined. 118 static PyLocation *getDefaultLocation(); 119 120 PyMlirContext *getContext(); 121 PyInsertionPoint *getInsertionPoint(); 122 PyLocation *getLocation(); 123 FrameKind getFrameKind() { return frameKind; } 124 125 /// Stack management. 126 static PyThreadContextEntry *getTopOfStack(); 127 static pybind11::object pushContext(PyMlirContext &context); 128 static void popContext(PyMlirContext &context); 129 static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); 130 static void popInsertionPoint(PyInsertionPoint &insertionPoint); 131 static pybind11::object pushLocation(PyLocation &location); 132 static void popLocation(PyLocation &location); 133 134 /// Gets the thread local stack. 135 static std::vector<PyThreadContextEntry> &getStack(); 136 137 private: 138 static void push(FrameKind frameKind, pybind11::object context, 139 pybind11::object insertionPoint, pybind11::object location); 140 141 /// An object reference to the PyContext. 142 pybind11::object context; 143 /// An object reference to the current insertion point. 144 pybind11::object insertionPoint; 145 /// An object reference to the current location. 146 pybind11::object location; 147 // The kind of push that was performed. 148 FrameKind frameKind; 149 }; 150 151 /// Wrapper around MlirContext. 152 using PyMlirContextRef = PyObjectRef<PyMlirContext>; 153 class PyMlirContext { 154 public: 155 PyMlirContext() = delete; 156 PyMlirContext(const PyMlirContext &) = delete; 157 PyMlirContext(PyMlirContext &&) = delete; 158 159 /// For the case of a python __init__ (py::init) method, pybind11 is quite 160 /// strict about needing to return a pointer that is not yet associated to 161 /// an py::object. Since the forContext() method acts like a pool, possibly 162 /// returning a recycled context, it does not satisfy this need. The usual 163 /// way in python to accomplish such a thing is to override __new__, but 164 /// that is also not supported by pybind11. Instead, we use this entry 165 /// point which always constructs a fresh context (which cannot alias an 166 /// existing one because it is fresh). 167 static PyMlirContext *createNewContextForInit(); 168 169 /// Returns a context reference for the singleton PyMlirContext wrapper for 170 /// the given context. 171 static PyMlirContextRef forContext(MlirContext context); 172 ~PyMlirContext(); 173 174 /// Accesses the underlying MlirContext. 175 MlirContext get() { return context; } 176 177 /// Gets a strong reference to this context, which will ensure it is kept 178 /// alive for the life of the reference. 179 PyMlirContextRef getRef() { 180 return PyMlirContextRef(this, pybind11::cast(this)); 181 } 182 183 /// Gets a capsule wrapping the void* within the MlirContext. 184 pybind11::object getCapsule(); 185 186 /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. 187 /// Note that PyMlirContext instances are uniqued, so the returned object 188 /// may be a pre-existing object. Ownership of the underlying MlirContext 189 /// is taken by calling this function. 190 static pybind11::object createFromCapsule(pybind11::object capsule); 191 192 /// Gets the count of live context objects. Used for testing. 193 static size_t getLiveCount(); 194 195 /// Gets the count of live operations associated with this context. 196 /// Used for testing. 197 size_t getLiveOperationCount(); 198 199 /// Gets the count of live modules associated with this context. 200 /// Used for testing. 201 size_t getLiveModuleCount(); 202 203 /// Enter and exit the context manager. 204 pybind11::object contextEnter(); 205 void contextExit(pybind11::object excType, pybind11::object excVal, 206 pybind11::object excTb); 207 208 private: 209 PyMlirContext(MlirContext context); 210 // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, 211 // preserving the relationship that an MlirContext maps to a single 212 // PyMlirContext wrapper. This could be replaced in the future with an 213 // extension mechanism on the MlirContext for stashing user pointers. 214 // Note that this holds a handle, which does not imply ownership. 215 // Mappings will be removed when the context is destructed. 216 using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>; 217 static LiveContextMap &getLiveContexts(); 218 219 // Interns all live modules associated with this context. Modules tracked 220 // in this map are valid. When a module is invalidated, it is removed 221 // from this map, and while it still exists as an instance, any 222 // attempt to access it will raise an error. 223 using LiveModuleMap = 224 llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>; 225 LiveModuleMap liveModules; 226 227 // Interns all live operations associated with this context. Operations 228 // tracked in this map are valid. When an operation is invalidated, it is 229 // removed from this map, and while it still exists as an instance, any 230 // attempt to access it will raise an error. 231 using LiveOperationMap = 232 llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>; 233 LiveOperationMap liveOperations; 234 235 MlirContext context; 236 friend class PyModule; 237 friend class PyOperation; 238 }; 239 240 /// Used in function arguments when None should resolve to the current context 241 /// manager set instance. 242 class DefaultingPyMlirContext 243 : public Defaulting<DefaultingPyMlirContext, PyMlirContext> { 244 public: 245 using Defaulting::Defaulting; 246 static constexpr const char kTypeDescription[] = 247 "[ThreadContextAware] mlir.ir.Context"; 248 static PyMlirContext &resolve(); 249 }; 250 251 /// Base class for all objects that directly or indirectly depend on an 252 /// MlirContext. The lifetime of the context will extend at least to the 253 /// lifetime of these instances. 254 /// Immutable objects that depend on a context extend this directly. 255 class BaseContextObject { 256 public: 257 BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { 258 assert(this->contextRef && 259 "context object constructed with null context ref"); 260 } 261 262 /// Accesses the context reference. 263 PyMlirContextRef &getContext() { return contextRef; } 264 265 private: 266 PyMlirContextRef contextRef; 267 }; 268 269 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in 270 /// order to differentiate it from the `Dialect` base class which is extended by 271 /// plugins which extend dialect functionality through extension python code. 272 /// This should be seen as the "low-level" object and `Dialect` as the 273 /// high-level, user facing object. 274 class PyDialectDescriptor : public BaseContextObject { 275 public: 276 PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) 277 : BaseContextObject(std::move(contextRef)), dialect(dialect) {} 278 279 MlirDialect get() { return dialect; } 280 281 private: 282 MlirDialect dialect; 283 }; 284 285 /// User-level object for accessing dialects with dotted syntax such as: 286 /// ctx.dialect.std 287 class PyDialects : public BaseContextObject { 288 public: 289 PyDialects(PyMlirContextRef contextRef) 290 : BaseContextObject(std::move(contextRef)) {} 291 292 MlirDialect getDialectForKey(const std::string &key, bool attrError); 293 }; 294 295 /// User-level dialect object. For dialects that have a registered extension, 296 /// this will be the base class of the extension dialect type. For un-extended, 297 /// objects of this type will be returned directly. 298 class PyDialect { 299 public: 300 PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} 301 302 pybind11::object getDescriptor() { return descriptor; } 303 304 private: 305 pybind11::object descriptor; 306 }; 307 308 /// Wrapper around an MlirLocation. 309 class PyLocation : public BaseContextObject { 310 public: 311 PyLocation(PyMlirContextRef contextRef, MlirLocation loc) 312 : BaseContextObject(std::move(contextRef)), loc(loc) {} 313 314 operator MlirLocation() const { return loc; } 315 MlirLocation get() const { return loc; } 316 317 /// Enter and exit the context manager. 318 pybind11::object contextEnter(); 319 void contextExit(pybind11::object excType, pybind11::object excVal, 320 pybind11::object excTb); 321 322 /// Gets a capsule wrapping the void* within the MlirLocation. 323 pybind11::object getCapsule(); 324 325 /// Creates a PyLocation from the MlirLocation wrapped by a capsule. 326 /// Note that PyLocation instances are uniqued, so the returned object 327 /// may be a pre-existing object. Ownership of the underlying MlirLocation 328 /// is taken by calling this function. 329 static PyLocation createFromCapsule(pybind11::object capsule); 330 331 private: 332 MlirLocation loc; 333 }; 334 335 /// Used in function arguments when None should resolve to the current context 336 /// manager set instance. 337 class DefaultingPyLocation 338 : public Defaulting<DefaultingPyLocation, PyLocation> { 339 public: 340 using Defaulting::Defaulting; 341 static constexpr const char kTypeDescription[] = 342 "[ThreadContextAware] mlir.ir.Location"; 343 static PyLocation &resolve(); 344 345 operator MlirLocation() const { return *get(); } 346 }; 347 348 /// Wrapper around MlirModule. 349 /// This is the top-level, user-owned object that contains regions/ops/blocks. 350 class PyModule; 351 using PyModuleRef = PyObjectRef<PyModule>; 352 class PyModule : public BaseContextObject { 353 public: 354 /// Returns a PyModule reference for the given MlirModule. This may return 355 /// a pre-existing or new object. 356 static PyModuleRef forModule(MlirModule module); 357 PyModule(PyModule &) = delete; 358 PyModule(PyMlirContext &&) = delete; 359 ~PyModule(); 360 361 /// Gets the backing MlirModule. 362 MlirModule get() { return module; } 363 364 /// Gets a strong reference to this module. 365 PyModuleRef getRef() { 366 return PyModuleRef(this, 367 pybind11::reinterpret_borrow<pybind11::object>(handle)); 368 } 369 370 /// Gets a capsule wrapping the void* within the MlirModule. 371 /// Note that the module does not (yet) provide a corresponding factory for 372 /// constructing from a capsule as that would require uniquing PyModule 373 /// instances, which is not currently done. 374 pybind11::object getCapsule(); 375 376 /// Creates a PyModule from the MlirModule wrapped by a capsule. 377 /// Note that PyModule instances are uniqued, so the returned object 378 /// may be a pre-existing object. Ownership of the underlying MlirModule 379 /// is taken by calling this function. 380 static pybind11::object createFromCapsule(pybind11::object capsule); 381 382 private: 383 PyModule(PyMlirContextRef contextRef, MlirModule module); 384 MlirModule module; 385 pybind11::handle handle; 386 }; 387 388 /// Base class for PyOperation and PyOpView which exposes the primary, user 389 /// visible methods for manipulating it. 390 class PyOperationBase { 391 public: 392 virtual ~PyOperationBase() = default; 393 /// Implements the bound 'print' method and helps with others. 394 void print(pybind11::object fileObject, bool binary, 395 llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo, 396 bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope); 397 pybind11::object getAsm(bool binary, 398 llvm::Optional<int64_t> largeElementsLimit, 399 bool enableDebugInfo, bool prettyDebugInfo, 400 bool printGenericOpForm, bool useLocalScope); 401 402 /// Moves the operation before or after the other operation. 403 void moveAfter(PyOperationBase &other); 404 void moveBefore(PyOperationBase &other); 405 406 /// Each must provide access to the raw Operation. 407 virtual PyOperation &getOperation() = 0; 408 }; 409 410 /// Wrapper around PyOperation. 411 /// Operations exist in either an attached (dependent) or detached (top-level) 412 /// state. In the detached state (as on creation), an operation is owned by 413 /// the creator and its lifetime extends either until its reference count 414 /// drops to zero or it is attached to a parent, at which point its lifetime 415 /// is bounded by its top-level parent reference. 416 class PyOperation; 417 using PyOperationRef = PyObjectRef<PyOperation>; 418 class PyOperation : public PyOperationBase, public BaseContextObject { 419 public: 420 ~PyOperation(); 421 PyOperation &getOperation() override { return *this; } 422 423 /// Returns a PyOperation for the given MlirOperation, optionally associating 424 /// it with a parentKeepAlive. 425 static PyOperationRef 426 forOperation(PyMlirContextRef contextRef, MlirOperation operation, 427 pybind11::object parentKeepAlive = pybind11::object()); 428 429 /// Creates a detached operation. The operation must not be associated with 430 /// any existing live operation. 431 static PyOperationRef 432 createDetached(PyMlirContextRef contextRef, MlirOperation operation, 433 pybind11::object parentKeepAlive = pybind11::object()); 434 435 /// Detaches the operation from its parent block and updates its state 436 /// accordingly. 437 void detachFromParent() { 438 mlirOperationRemoveFromParent(getOperation()); 439 setDetached(); 440 parentKeepAlive = pybind11::object(); 441 } 442 443 /// Gets the backing operation. 444 operator MlirOperation() const { return get(); } 445 MlirOperation get() const { 446 checkValid(); 447 return operation; 448 } 449 450 PyOperationRef getRef() { 451 return PyOperationRef( 452 this, pybind11::reinterpret_borrow<pybind11::object>(handle)); 453 } 454 455 bool isAttached() { return attached; } 456 void setAttached(pybind11::object parent = pybind11::object()) { 457 assert(!attached && "operation already attached"); 458 attached = true; 459 } 460 void setDetached() { 461 assert(attached && "operation already detached"); 462 attached = false; 463 } 464 void checkValid() const; 465 466 /// Gets the owning block or raises an exception if the operation has no 467 /// owning block. 468 PyBlock getBlock(); 469 470 /// Gets the parent operation or raises an exception if the operation has 471 /// no parent. 472 llvm::Optional<PyOperationRef> getParentOperation(); 473 474 /// Gets a capsule wrapping the void* within the MlirOperation. 475 pybind11::object getCapsule(); 476 477 /// Creates a PyOperation from the MlirOperation wrapped by a capsule. 478 /// Ownership of the underlying MlirOperation is taken by calling this 479 /// function. 480 static pybind11::object createFromCapsule(pybind11::object capsule); 481 482 /// Creates an operation. See corresponding python docstring. 483 static pybind11::object 484 create(std::string name, llvm::Optional<std::vector<PyType *>> results, 485 llvm::Optional<std::vector<PyValue *>> operands, 486 llvm::Optional<pybind11::dict> attributes, 487 llvm::Optional<std::vector<PyBlock *>> successors, int regions, 488 DefaultingPyLocation location, pybind11::object ip); 489 490 /// Creates an OpView suitable for this operation. 491 pybind11::object createOpView(); 492 493 /// Erases the underlying MlirOperation, removes its pointer from the 494 /// parent context's live operations map, and sets the valid bit false. 495 void erase(); 496 497 private: 498 PyOperation(PyMlirContextRef contextRef, MlirOperation operation); 499 static PyOperationRef createInstance(PyMlirContextRef contextRef, 500 MlirOperation operation, 501 pybind11::object parentKeepAlive); 502 503 MlirOperation operation; 504 pybind11::handle handle; 505 // Keeps the parent alive, regardless of whether it is an Operation or 506 // Module. 507 // TODO: As implemented, this facility is only sufficient for modeling the 508 // trivial module parent back-reference. Generalize this to also account for 509 // transitions from detached to attached and address TODOs in the 510 // ir_operation.py regarding testing corresponding lifetime guarantees. 511 pybind11::object parentKeepAlive; 512 bool attached = true; 513 bool valid = true; 514 515 friend class PyOperationBase; 516 }; 517 518 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for 519 /// providing more instance-specific accessors and serve as the base class for 520 /// custom ODS-style operation classes. Since this class is subclass on the 521 /// python side, it must present an __init__ method that operates in pure 522 /// python types. 523 class PyOpView : public PyOperationBase { 524 public: 525 PyOpView(pybind11::object operationObject); 526 PyOperation &getOperation() override { return operation; } 527 528 static pybind11::object createRawSubclass(pybind11::object userClass); 529 530 pybind11::object getOperationObject() { return operationObject; } 531 532 static pybind11::object 533 buildGeneric(pybind11::object cls, pybind11::list resultTypeList, 534 pybind11::list operandList, 535 llvm::Optional<pybind11::dict> attributes, 536 llvm::Optional<std::vector<PyBlock *>> successors, 537 llvm::Optional<int> regions, DefaultingPyLocation location, 538 pybind11::object maybeIp); 539 540 private: 541 PyOperation &operation; // For efficient, cast-free access from C++ 542 pybind11::object operationObject; // Holds the reference. 543 }; 544 545 /// Wrapper around an MlirRegion. 546 /// Regions are managed completely by their containing operation. Unlike the 547 /// C++ API, the python API does not support detached regions. 548 class PyRegion { 549 public: 550 PyRegion(PyOperationRef parentOperation, MlirRegion region) 551 : parentOperation(std::move(parentOperation)), region(region) { 552 assert(!mlirRegionIsNull(region) && "python region cannot be null"); 553 } 554 operator MlirRegion() const { return region; } 555 556 MlirRegion get() { return region; } 557 PyOperationRef &getParentOperation() { return parentOperation; } 558 559 void checkValid() { return parentOperation->checkValid(); } 560 561 private: 562 PyOperationRef parentOperation; 563 MlirRegion region; 564 }; 565 566 /// Wrapper around an MlirBlock. 567 /// Blocks are managed completely by their containing operation. Unlike the 568 /// C++ API, the python API does not support detached blocks. 569 class PyBlock { 570 public: 571 PyBlock(PyOperationRef parentOperation, MlirBlock block) 572 : parentOperation(std::move(parentOperation)), block(block) { 573 assert(!mlirBlockIsNull(block) && "python block cannot be null"); 574 } 575 576 MlirBlock get() { return block; } 577 PyOperationRef &getParentOperation() { return parentOperation; } 578 579 void checkValid() { return parentOperation->checkValid(); } 580 581 private: 582 PyOperationRef parentOperation; 583 MlirBlock block; 584 }; 585 586 /// An insertion point maintains a pointer to a Block and a reference operation. 587 /// Calls to insert() will insert a new operation before the 588 /// reference operation. If the reference operation is null, then appends to 589 /// the end of the block. 590 class PyInsertionPoint { 591 public: 592 /// Creates an insertion point positioned after the last operation in the 593 /// block, but still inside the block. 594 PyInsertionPoint(PyBlock &block); 595 /// Creates an insertion point positioned before a reference operation. 596 PyInsertionPoint(PyOperationBase &beforeOperationBase); 597 598 /// Shortcut to create an insertion point at the beginning of the block. 599 static PyInsertionPoint atBlockBegin(PyBlock &block); 600 /// Shortcut to create an insertion point before the block terminator. 601 static PyInsertionPoint atBlockTerminator(PyBlock &block); 602 603 /// Inserts an operation. 604 void insert(PyOperationBase &operationBase); 605 606 /// Enter and exit the context manager. 607 pybind11::object contextEnter(); 608 void contextExit(pybind11::object excType, pybind11::object excVal, 609 pybind11::object excTb); 610 611 PyBlock &getBlock() { return block; } 612 613 private: 614 // Trampoline constructor that avoids null initializing members while 615 // looking up parents. 616 PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation) 617 : refOperation(std::move(refOperation)), block(std::move(block)) {} 618 619 llvm::Optional<PyOperationRef> refOperation; 620 PyBlock block; 621 }; 622 /// Wrapper around the generic MlirType. 623 /// The lifetime of a type is bound by the PyContext that created it. 624 class PyType : public BaseContextObject { 625 public: 626 PyType(PyMlirContextRef contextRef, MlirType type) 627 : BaseContextObject(std::move(contextRef)), type(type) {} 628 bool operator==(const PyType &other); 629 operator MlirType() const { return type; } 630 MlirType get() const { return type; } 631 632 /// Gets a capsule wrapping the void* within the MlirType. 633 pybind11::object getCapsule(); 634 635 /// Creates a PyType from the MlirType wrapped by a capsule. 636 /// Note that PyType instances are uniqued, so the returned object 637 /// may be a pre-existing object. Ownership of the underlying MlirType 638 /// is taken by calling this function. 639 static PyType createFromCapsule(pybind11::object capsule); 640 641 private: 642 MlirType type; 643 }; 644 645 /// CRTP base classes for Python types that subclass Type and should be 646 /// castable from it (i.e. via something like IntegerType(t)). 647 /// By default, type class hierarchies are one level deep (i.e. a 648 /// concrete type class extends PyType); however, intermediate python-visible 649 /// base classes can be modeled by specifying a BaseTy. 650 template <typename DerivedTy, typename BaseTy = PyType> 651 class PyConcreteType : public BaseTy { 652 public: 653 // Derived classes must define statics for: 654 // IsAFunctionTy isaFunction 655 // const char *pyClassName 656 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 657 using IsAFunctionTy = bool (*)(MlirType); 658 659 PyConcreteType() = default; 660 PyConcreteType(PyMlirContextRef contextRef, MlirType t) 661 : BaseTy(std::move(contextRef), t) {} 662 PyConcreteType(PyType &orig) 663 : PyConcreteType(orig.getContext(), castFrom(orig)) {} 664 665 static MlirType castFrom(PyType &orig) { 666 if (!DerivedTy::isaFunction(orig)) { 667 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 668 throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + 669 DerivedTy::pyClassName + 670 " (from " + origRepr + ")"); 671 } 672 return orig; 673 } 674 675 static void bind(pybind11::module &m) { 676 auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); 677 cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>()); 678 cls.def_static("isinstance", [](PyType &otherType) -> bool { 679 return DerivedTy::isaFunction(otherType); 680 }); 681 DerivedTy::bindDerived(cls); 682 } 683 684 /// Implemented by derived classes to add methods to the Python subclass. 685 static void bindDerived(ClassTy &m) {} 686 }; 687 688 /// Wrapper around the generic MlirAttribute. 689 /// The lifetime of a type is bound by the PyContext that created it. 690 class PyAttribute : public BaseContextObject { 691 public: 692 PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 693 : BaseContextObject(std::move(contextRef)), attr(attr) {} 694 bool operator==(const PyAttribute &other); 695 operator MlirAttribute() const { return attr; } 696 MlirAttribute get() const { return attr; } 697 698 /// Gets a capsule wrapping the void* within the MlirAttribute. 699 pybind11::object getCapsule(); 700 701 /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. 702 /// Note that PyAttribute instances are uniqued, so the returned object 703 /// may be a pre-existing object. Ownership of the underlying MlirAttribute 704 /// is taken by calling this function. 705 static PyAttribute createFromCapsule(pybind11::object capsule); 706 707 private: 708 MlirAttribute attr; 709 }; 710 711 /// Represents a Python MlirNamedAttr, carrying an optional owned name. 712 /// TODO: Refactor this and the C-API to be based on an Identifier owned 713 /// by the context so as to avoid ownership issues here. 714 class PyNamedAttribute { 715 public: 716 /// Constructs a PyNamedAttr that retains an owned name. This should be 717 /// used in any code that originates an MlirNamedAttribute from a python 718 /// string. 719 /// The lifetime of the PyNamedAttr must extend to the lifetime of the 720 /// passed attribute. 721 PyNamedAttribute(MlirAttribute attr, std::string ownedName); 722 723 MlirNamedAttribute namedAttr; 724 725 private: 726 // Since the MlirNamedAttr contains an internal pointer to the actual 727 // memory of the owned string, it must be heap allocated to remain valid. 728 // Otherwise, strings that fit within the small object optimization threshold 729 // will have their memory address change as the containing object is moved, 730 // resulting in an invalid aliased pointer. 731 std::unique_ptr<std::string> ownedName; 732 }; 733 734 /// CRTP base classes for Python attributes that subclass Attribute and should 735 /// be castable from it (i.e. via something like StringAttr(attr)). 736 /// By default, attribute class hierarchies are one level deep (i.e. a 737 /// concrete attribute class extends PyAttribute); however, intermediate 738 /// python-visible base classes can be modeled by specifying a BaseTy. 739 template <typename DerivedTy, typename BaseTy = PyAttribute> 740 class PyConcreteAttribute : public BaseTy { 741 public: 742 // Derived classes must define statics for: 743 // IsAFunctionTy isaFunction 744 // const char *pyClassName 745 using ClassTy = pybind11::class_<DerivedTy, BaseTy>; 746 using IsAFunctionTy = bool (*)(MlirAttribute); 747 748 PyConcreteAttribute() = default; 749 PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) 750 : BaseTy(std::move(contextRef), attr) {} 751 PyConcreteAttribute(PyAttribute &orig) 752 : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} 753 754 static MlirAttribute castFrom(PyAttribute &orig) { 755 if (!DerivedTy::isaFunction(orig)) { 756 auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>(); 757 throw SetPyError(PyExc_ValueError, 758 llvm::Twine("Cannot cast attribute to ") + 759 DerivedTy::pyClassName + " (from " + origRepr + ")"); 760 } 761 return orig; 762 } 763 764 static void bind(pybind11::module &m) { 765 auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), 766 pybind11::module_local()); 767 cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>()); 768 cls.def_static("isinstance", [](PyAttribute &otherAttr) -> bool { 769 return DerivedTy::isaFunction(otherAttr); 770 }); 771 cls.def_property_readonly("type", [](PyAttribute &attr) { 772 return PyType(attr.getContext(), mlirAttributeGetType(attr)); 773 }); 774 DerivedTy::bindDerived(cls); 775 } 776 777 /// Implemented by derived classes to add methods to the Python subclass. 778 static void bindDerived(ClassTy &m) {} 779 }; 780 781 /// Wrapper around the generic MlirValue. 782 /// Values are managed completely by the operation that resulted in their 783 /// definition. For op result value, this is the operation that defines the 784 /// value. For block argument values, this is the operation that contains the 785 /// block to which the value is an argument (blocks cannot be detached in Python 786 /// bindings so such operation always exists). 787 class PyValue { 788 public: 789 PyValue(PyOperationRef parentOperation, MlirValue value) 790 : parentOperation(parentOperation), value(value) {} 791 operator MlirValue() const { return value; } 792 793 MlirValue get() { return value; } 794 PyOperationRef &getParentOperation() { return parentOperation; } 795 796 void checkValid() { return parentOperation->checkValid(); } 797 798 /// Gets a capsule wrapping the void* within the MlirValue. 799 pybind11::object getCapsule(); 800 801 /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of 802 /// the underlying MlirValue is still tied to the owning operation. 803 static PyValue createFromCapsule(pybind11::object capsule); 804 805 private: 806 PyOperationRef parentOperation; 807 MlirValue value; 808 }; 809 810 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. 811 class PyAffineExpr : public BaseContextObject { 812 public: 813 PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) 814 : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} 815 bool operator==(const PyAffineExpr &other); 816 operator MlirAffineExpr() const { return affineExpr; } 817 MlirAffineExpr get() const { return affineExpr; } 818 819 /// Gets a capsule wrapping the void* within the MlirAffineExpr. 820 pybind11::object getCapsule(); 821 822 /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. 823 /// Note that PyAffineExpr instances are uniqued, so the returned object 824 /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr 825 /// is taken by calling this function. 826 static PyAffineExpr createFromCapsule(pybind11::object capsule); 827 828 PyAffineExpr add(const PyAffineExpr &other) const; 829 PyAffineExpr mul(const PyAffineExpr &other) const; 830 PyAffineExpr floorDiv(const PyAffineExpr &other) const; 831 PyAffineExpr ceilDiv(const PyAffineExpr &other) const; 832 PyAffineExpr mod(const PyAffineExpr &other) const; 833 834 private: 835 MlirAffineExpr affineExpr; 836 }; 837 838 class PyAffineMap : public BaseContextObject { 839 public: 840 PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) 841 : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} 842 bool operator==(const PyAffineMap &other); 843 operator MlirAffineMap() const { return affineMap; } 844 MlirAffineMap get() const { return affineMap; } 845 846 /// Gets a capsule wrapping the void* within the MlirAffineMap. 847 pybind11::object getCapsule(); 848 849 /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. 850 /// Note that PyAffineMap instances are uniqued, so the returned object 851 /// may be a pre-existing object. Ownership of the underlying MlirAffineMap 852 /// is taken by calling this function. 853 static PyAffineMap createFromCapsule(pybind11::object capsule); 854 855 private: 856 MlirAffineMap affineMap; 857 }; 858 859 class PyIntegerSet : public BaseContextObject { 860 public: 861 PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) 862 : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} 863 bool operator==(const PyIntegerSet &other); 864 operator MlirIntegerSet() const { return integerSet; } 865 MlirIntegerSet get() const { return integerSet; } 866 867 /// Gets a capsule wrapping the void* within the MlirIntegerSet. 868 pybind11::object getCapsule(); 869 870 /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. 871 /// Note that PyIntegerSet instances may be uniqued, so the returned object 872 /// may be a pre-existing object. Integer sets are owned by the context. 873 static PyIntegerSet createFromCapsule(pybind11::object capsule); 874 875 private: 876 MlirIntegerSet integerSet; 877 }; 878 879 void populateIRAffine(pybind11::module &m); 880 void populateIRAttributes(pybind11::module &m); 881 void populateIRCore(pybind11::module &m); 882 void populateIRInterfaces(pybind11::module &m); 883 void populateIRTypes(pybind11::module &m); 884 885 } // namespace python 886 } // namespace mlir 887 888 namespace pybind11 { 889 namespace detail { 890 891 template <> 892 struct type_caster<mlir::python::DefaultingPyMlirContext> 893 : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {}; 894 template <> 895 struct type_caster<mlir::python::DefaultingPyLocation> 896 : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {}; 897 898 } // namespace detail 899 } // namespace pybind11 900 901 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H 902