1 //===- IRModules.h - IR Submodules of pybind module -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
10 #define MLIR_BINDINGS_PYTHON_IRMODULES_H
11 
12 #include <vector>
13 
14 #include "PybindUtils.h"
15 
16 #include "mlir-c/AffineExpr.h"
17 #include "mlir-c/AffineMap.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/IntegerSet.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/Optional.h"
22 
23 namespace mlir {
24 namespace python {
25 
26 class PyBlock;
27 class PyInsertionPoint;
28 class PyLocation;
29 class DefaultingPyLocation;
30 class PyMlirContext;
31 class DefaultingPyMlirContext;
32 class PyModule;
33 class PyOperation;
34 class PyType;
35 class PySymbolTable;
36 class PyValue;
37 
38 /// Template for a reference to a concrete type which captures a python
39 /// reference to its underlying python object.
40 template <typename T>
41 class PyObjectRef {
42 public:
43   PyObjectRef(T *referrent, pybind11::object object)
44       : referrent(referrent), object(std::move(object)) {
45     assert(this->referrent &&
46            "cannot construct PyObjectRef with null referrent");
47     assert(this->object && "cannot construct PyObjectRef with null object");
48   }
49   PyObjectRef(PyObjectRef &&other)
50       : referrent(other.referrent), object(std::move(other.object)) {
51     other.referrent = nullptr;
52     assert(!other.object);
53   }
54   PyObjectRef(const PyObjectRef &other)
55       : referrent(other.referrent), object(other.object /* copies */) {}
56   ~PyObjectRef() {}
57 
58   int getRefCount() {
59     if (!object)
60       return 0;
61     return object.ref_count();
62   }
63 
64   /// Releases the object held by this instance, returning it.
65   /// This is the proper thing to return from a function that wants to return
66   /// the reference. Note that this does not work from initializers.
67   pybind11::object releaseObject() {
68     assert(referrent && object);
69     referrent = nullptr;
70     auto stolen = std::move(object);
71     return stolen;
72   }
73 
74   T *get() { return referrent; }
75   T *operator->() {
76     assert(referrent && object);
77     return referrent;
78   }
79   pybind11::object getObject() {
80     assert(referrent && object);
81     return object;
82   }
83   operator bool() const { return referrent && object; }
84 
85 private:
86   T *referrent;
87   pybind11::object object;
88 };
89 
/// Tracks an entry in the thread context stack. New entries are pushed onto
/// here for each `with` block that activates a new InsertionPoint, Context or
/// Location.
///
/// Pushing either a Location or InsertionPoint also pushes its associated
/// Context. Pushing a Context will not modify the Location or InsertionPoint
/// unless they are from a different context, in which case they are
/// cleared.
class PyThreadContextEntry {
public:
  /// Identifies which kind of object caused this stack frame to be pushed.
  enum class FrameKind {
    Context,
    InsertionPoint,
    Location,
  };

  /// Constructs an entry holding python references to the context, insertion
  /// point and location that are active for this frame.
  /// NOTE(review): presumably insertionPoint/location may be empty objects
  /// when not active for the frame — confirm against push().
  PyThreadContextEntry(FrameKind frameKind, pybind11::object context,
                       pybind11::object insertionPoint,
                       pybind11::object location)
      : context(std::move(context)), insertionPoint(std::move(insertionPoint)),
        location(std::move(location)), frameKind(frameKind) {}

  /// Gets the top of stack context and returns nullptr if not defined.
  static PyMlirContext *getDefaultContext();

  /// Gets the top of stack insertion point and returns nullptr if not defined.
  static PyInsertionPoint *getDefaultInsertionPoint();

  /// Gets the top of stack location and returns nullptr if not defined.
  static PyLocation *getDefaultLocation();

  /// Accessors for the native objects referenced by this entry.
  PyMlirContext *getContext();
  PyInsertionPoint *getInsertionPoint();
  PyLocation *getLocation();
  /// Returns which kind of push created this entry.
  FrameKind getFrameKind() { return frameKind; }

  /// Stack management.
  /// Each push* returns a python object; each pop* takes the object that is
  /// expected to be on top of the stack.
  static PyThreadContextEntry *getTopOfStack();
  static pybind11::object pushContext(PyMlirContext &context);
  static void popContext(PyMlirContext &context);
  static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint);
  static void popInsertionPoint(PyInsertionPoint &insertionPoint);
  static pybind11::object pushLocation(PyLocation &location);
  static void popLocation(PyLocation &location);

  /// Gets the thread local stack.
  static std::vector<PyThreadContextEntry> &getStack();

private:
  /// Shared implementation for the public push* entry points.
  static void push(FrameKind frameKind, pybind11::object context,
                   pybind11::object insertionPoint, pybind11::object location);

  /// An object reference to the PyContext.
  pybind11::object context;
  /// An object reference to the current insertion point.
  pybind11::object insertionPoint;
  /// An object reference to the current location.
  pybind11::object location;
  /// The kind of push that was performed.
  FrameKind frameKind;
};
151 
/// Wrapper around MlirContext.
using PyMlirContextRef = PyObjectRef<PyMlirContext>;
class PyMlirContext {
public:
  // Instances are interned per MlirContext (see getLiveContexts), so they can
  // be neither default-constructed, copied nor moved.
  PyMlirContext() = delete;
  PyMlirContext(const PyMlirContext &) = delete;
  PyMlirContext(PyMlirContext &&) = delete;

  /// For the case of a python __init__ (py::init) method, pybind11 is quite
  /// strict about needing to return a pointer that is not yet associated to
  /// a py::object. Since the forContext() method acts like a pool, possibly
  /// returning a recycled context, it does not satisfy this need. The usual
  /// way in python to accomplish such a thing is to override __new__, but
  /// that is also not supported by pybind11. Instead, we use this entry
  /// point which always constructs a fresh context (which cannot alias an
  /// existing one because it is fresh).
  static PyMlirContext *createNewContextForInit();

  /// Returns a context reference for the singleton PyMlirContext wrapper for
  /// the given context.
  static PyMlirContextRef forContext(MlirContext context);
  ~PyMlirContext();

  /// Accesses the underlying MlirContext.
  MlirContext get() { return context; }

  /// Gets a strong reference to this context, which will ensure it is kept
  /// alive for the life of the reference.
  PyMlirContextRef getRef() {
    // pybind11::cast of a pointer to an already-registered instance yields
    // the existing python object for it.
    return PyMlirContextRef(this, pybind11::cast(this));
  }

  /// Gets a capsule wrapping the void* within the MlirContext.
  pybind11::object getCapsule();

  /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
  /// Note that PyMlirContext instances are uniqued, so the returned object
  /// may be a pre-existing object. Ownership of the underlying MlirContext
  /// is taken by calling this function.
  static pybind11::object createFromCapsule(pybind11::object capsule);

  /// Gets the count of live context objects. Used for testing.
  static size_t getLiveCount();

  /// Gets the count of live operations associated with this context.
  /// Used for testing.
  size_t getLiveOperationCount();

  /// Gets the count of live modules associated with this context.
  /// Used for testing.
  size_t getLiveModuleCount();

  /// Enter and exit the context manager.
  pybind11::object contextEnter();
  void contextExit(const pybind11::object &excType,
                   const pybind11::object &excVal,
                   const pybind11::object &excTb);

private:
  PyMlirContext(MlirContext context);
  // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
  // preserving the relationship that an MlirContext maps to a single
  // PyMlirContext wrapper. This could be replaced in the future with an
  // extension mechanism on the MlirContext for stashing user pointers.
  // Note that this holds a handle, which does not imply ownership.
  // Mappings will be removed when the context is destructed.
  using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
  static LiveContextMap &getLiveContexts();

  // Interns all live modules associated with this context. Modules tracked
  // in this map are valid. When a module is invalidated, it is removed
  // from this map, and while it still exists as an instance, any
  // attempt to access it will raise an error.
  // Each entry pairs a non-owning python handle with the native wrapper.
  using LiveModuleMap =
      llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
  LiveModuleMap liveModules;

  // Interns all live operations associated with this context. Operations
  // tracked in this map are valid. When an operation is invalidated, it is
  // removed from this map, and while it still exists as an instance, any
  // attempt to access it will raise an error.
  // Each entry pairs a non-owning python handle with the native wrapper.
  using LiveOperationMap =
      llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
  LiveOperationMap liveOperations;

  /// The wrapped context.
  MlirContext context;
  // PyModule and PyOperation register/unregister themselves in the live maps.
  friend class PyModule;
  friend class PyOperation;
};
241 
/// Used in function arguments when None should resolve to the context
/// instance set by the currently active context manager.
class DefaultingPyMlirContext
    : public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
public:
  using Defaulting::Defaulting;
  /// Python-facing type name for this argument kind.
  /// NOTE(review): presumably consumed by Defaulting for signatures/error
  /// messages — confirm in PybindUtils.h.
  static constexpr const char kTypeDescription[] = "mlir.ir.Context";
  /// Resolves the default context when the argument was None.
  static PyMlirContext &resolve();
};
251 
252 /// Base class for all objects that directly or indirectly depend on an
253 /// MlirContext. The lifetime of the context will extend at least to the
254 /// lifetime of these instances.
255 /// Immutable objects that depend on a context extend this directly.
256 class BaseContextObject {
257 public:
258   BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
259     assert(this->contextRef &&
260            "context object constructed with null context ref");
261   }
262 
263   /// Accesses the context reference.
264   PyMlirContextRef &getContext() { return contextRef; }
265 
266 private:
267   PyMlirContextRef contextRef;
268 };
269 
270 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
271 /// order to differentiate it from the `Dialect` base class which is extended by
272 /// plugins which extend dialect functionality through extension python code.
273 /// This should be seen as the "low-level" object and `Dialect` as the
274 /// high-level, user facing object.
275 class PyDialectDescriptor : public BaseContextObject {
276 public:
277   PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
278       : BaseContextObject(std::move(contextRef)), dialect(dialect) {}
279 
280   MlirDialect get() { return dialect; }
281 
282 private:
283   MlirDialect dialect;
284 };
285 
/// User-level object for accessing dialects with dotted syntax such as:
///   ctx.dialect.std
class PyDialects : public BaseContextObject {
public:
  PyDialects(PyMlirContextRef contextRef)
      : BaseContextObject(std::move(contextRef)) {}

  /// Looks up the dialect identified by `key` in this object's context.
  /// NOTE(review): `attrError` presumably selects between reporting a missing
  /// dialect as an attribute-access error (dotted syntax) or an index-access
  /// error — confirm against the implementation.
  MlirDialect getDialectForKey(const std::string &key, bool attrError);
};
295 
296 /// User-level dialect object. For dialects that have a registered extension,
297 /// this will be the base class of the extension dialect type. For un-extended,
298 /// objects of this type will be returned directly.
299 class PyDialect {
300 public:
301   PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}
302 
303   pybind11::object getDescriptor() { return descriptor; }
304 
305 private:
306   pybind11::object descriptor;
307 };
308 
/// Wrapper around an MlirLocation.
/// Holds a strong reference to its context via BaseContextObject.
class PyLocation : public BaseContextObject {
public:
  PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
      : BaseContextObject(std::move(contextRef)), loc(loc) {}

  /// Accesses the wrapped MlirLocation (implicitly or explicitly).
  operator MlirLocation() const { return loc; }
  MlirLocation get() const { return loc; }

  /// Enter and exit the context manager.
  pybind11::object contextEnter();
  void contextExit(const pybind11::object &excType,
                   const pybind11::object &excVal,
                   const pybind11::object &excTb);

  /// Gets a capsule wrapping the void* within the MlirLocation.
  pybind11::object getCapsule();

  /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
  /// The result is a new PyLocation returned by value; it is not a uniqued or
  /// pre-existing python object.
  /// NOTE(review): earlier docs claimed this takes ownership of the
  /// MlirLocation; locations are presumably owned by their context instead —
  /// confirm against the implementation.
  static PyLocation createFromCapsule(pybind11::object capsule);

private:
  MlirLocation loc;
};
336 
/// Used in function arguments when None should resolve to the location
/// instance set by the currently active context manager.
class DefaultingPyLocation
    : public Defaulting<DefaultingPyLocation, PyLocation> {
public:
  using Defaulting::Defaulting;
  /// Python-facing type name for this argument kind.
  static constexpr const char kTypeDescription[] = "mlir.ir.Location";
  /// Resolves the default location when the argument was None.
  static PyLocation &resolve();

  /// Converts to MlirLocation by dereferencing the held PyLocation (obtained
  /// from Defaulting::get()) and using its MlirLocation conversion.
  operator MlirLocation() const { return *get(); }
};
348 
349 /// Wrapper around MlirModule.
350 /// This is the top-level, user-owned object that contains regions/ops/blocks.
351 class PyModule;
352 using PyModuleRef = PyObjectRef<PyModule>;
353 class PyModule : public BaseContextObject {
354 public:
355   /// Returns a PyModule reference for the given MlirModule. This may return
356   /// a pre-existing or new object.
357   static PyModuleRef forModule(MlirModule module);
358   PyModule(PyModule &) = delete;
359   PyModule(PyMlirContext &&) = delete;
360   ~PyModule();
361 
362   /// Gets the backing MlirModule.
363   MlirModule get() { return module; }
364 
365   /// Gets a strong reference to this module.
366   PyModuleRef getRef() {
367     return PyModuleRef(this,
368                        pybind11::reinterpret_borrow<pybind11::object>(handle));
369   }
370 
371   /// Gets a capsule wrapping the void* within the MlirModule.
372   /// Note that the module does not (yet) provide a corresponding factory for
373   /// constructing from a capsule as that would require uniquing PyModule
374   /// instances, which is not currently done.
375   pybind11::object getCapsule();
376 
377   /// Creates a PyModule from the MlirModule wrapped by a capsule.
378   /// Note that PyModule instances are uniqued, so the returned object
379   /// may be a pre-existing object. Ownership of the underlying MlirModule
380   /// is taken by calling this function.
381   static pybind11::object createFromCapsule(pybind11::object capsule);
382 
383 private:
384   PyModule(PyMlirContextRef contextRef, MlirModule module);
385   MlirModule module;
386   pybind11::handle handle;
387 };
388 
/// Base class for PyOperation and PyOpView which exposes the primary, user
/// visible methods for manipulating it.
class PyOperationBase {
public:
  virtual ~PyOperationBase() = default;
  /// Implements the bound 'print' method and helps with others.
  /// Output goes to `fileObject`; the remaining flags control the printing
  /// options.
  void print(pybind11::object fileObject, bool binary,
             llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
             bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
             bool assumeVerified);
  /// Returns the printed form of the operation using the same options as
  /// print().
  /// NOTE(review): presumably a str, or bytes when `binary` is set — confirm.
  pybind11::object getAsm(bool binary,
                          llvm::Optional<int64_t> largeElementsLimit,
                          bool enableDebugInfo, bool prettyDebugInfo,
                          bool printGenericOpForm, bool useLocalScope,
                          bool assumeVerified);

  /// Moves the operation before or after the other operation.
  void moveAfter(PyOperationBase &other);
  void moveBefore(PyOperationBase &other);

  /// Each must provide access to the raw Operation.
  virtual PyOperation &getOperation() = 0;
};
412 
/// Wrapper around MlirOperation.
/// Operations exist in either an attached (dependent) or detached (top-level)
/// state. In the detached state (as on creation), an operation is owned by
/// the creator and its lifetime extends either until its reference count
/// drops to zero or it is attached to a parent, at which point its lifetime
/// is bounded by its top-level parent reference.
class PyOperation;
using PyOperationRef = PyObjectRef<PyOperation>;
class PyOperation : public PyOperationBase, public BaseContextObject {
public:
  ~PyOperation();
  /// A PyOperation is its own underlying operation.
  PyOperation &getOperation() override { return *this; }

  /// Returns a PyOperation for the given MlirOperation, optionally associating
  /// it with a parentKeepAlive.
  static PyOperationRef
  forOperation(PyMlirContextRef contextRef, MlirOperation operation,
               pybind11::object parentKeepAlive = pybind11::object());

  /// Creates a detached operation. The operation must not be associated with
  /// any existing live operation.
  static PyOperationRef
  createDetached(PyMlirContextRef contextRef, MlirOperation operation,
                 pybind11::object parentKeepAlive = pybind11::object());

  /// Detaches the operation from its parent block and updates its state
  /// accordingly.
  void detachFromParent() {
    // getOperation() converts via operator MlirOperation, which calls get()
    // and therefore checkValid() first.
    mlirOperationRemoveFromParent(getOperation());
    setDetached();
    // Release the reference that was keeping the former parent alive.
    parentKeepAlive = pybind11::object();
  }

  /// Gets the backing operation after checking that this wrapper is still
  /// valid.
  operator MlirOperation() const { return get(); }
  MlirOperation get() const {
    checkValid();
    return operation;
  }

  /// Gets a strong reference to this operation by borrowing the non-owning
  /// `handle` to its python object.
  PyOperationRef getRef() {
    return PyOperationRef(
        this, pybind11::reinterpret_borrow<pybind11::object>(handle));
  }

  /// Attachment state accessors.
  bool isAttached() { return attached; }
  /// Marks the operation as attached; asserts it was detached.
  /// NOTE(review): `parent` is currently unused (parentKeepAlive is not
  /// updated here) — confirm whether that is intended.
  void setAttached(pybind11::object parent = pybind11::object()) {
    assert(!attached && "operation already attached");
    attached = true;
  }
  /// Marks the operation as detached; asserts it was attached.
  void setDetached() {
    assert(attached && "operation already detached");
    attached = false;
  }
  /// Checks that the underlying operation has not been invalidated (e.g. by
  /// erase()).
  void checkValid() const;

  /// Gets the owning block or raises an exception if the operation has no
  /// owning block.
  PyBlock getBlock();

  /// Gets the parent operation.
  /// NOTE(review): the return type is Optional, so a missing parent is
  /// presumably reported as an empty value rather than an exception — confirm.
  llvm::Optional<PyOperationRef> getParentOperation();

  /// Gets a capsule wrapping the void* within the MlirOperation.
  pybind11::object getCapsule();

  /// Creates a PyOperation from the MlirOperation wrapped by a capsule.
  /// Ownership of the underlying MlirOperation is taken by calling this
  /// function.
  static pybind11::object createFromCapsule(pybind11::object capsule);

  /// Creates an operation. See corresponding python docstring.
  static pybind11::object
  create(const std::string &name, llvm::Optional<std::vector<PyType *>> results,
         llvm::Optional<std::vector<PyValue *>> operands,
         llvm::Optional<pybind11::dict> attributes,
         llvm::Optional<std::vector<PyBlock *>> successors, int regions,
         DefaultingPyLocation location, const pybind11::object &ip);

  /// Creates an OpView suitable for this operation.
  pybind11::object createOpView();

  /// Erases the underlying MlirOperation, removes its pointer from the
  /// parent context's live operations map, and sets the valid bit false.
  void erase();

private:
  PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
  static PyOperationRef createInstance(PyMlirContextRef contextRef,
                                       MlirOperation operation,
                                       pybind11::object parentKeepAlive);

  /// The wrapped operation; access through get() to enforce checkValid().
  MlirOperation operation;
  /// Non-owning handle to the python object wrapping this operation.
  pybind11::handle handle;
  // Keeps the parent alive, regardless of whether it is an Operation or
  // Module.
  // TODO: As implemented, this facility is only sufficient for modeling the
  // trivial module parent back-reference. Generalize this to also account for
  // transitions from detached to attached and address TODOs in the
  // ir_operation.py regarding testing corresponding lifetime guarantees.
  pybind11::object parentKeepAlive;
  bool attached = true;
  bool valid = true;

  friend class PyOperationBase;
  friend class PySymbolTable;
};
521 
/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
/// providing more instance-specific accessors and serve as the base class for
/// custom ODS-style operation classes. Since this class is subclassed on the
/// python side, it must present an __init__ method that operates in pure
/// python types.
class PyOpView : public PyOperationBase {
public:
  /// Wraps the PyOperation held by `operationObject`.
  PyOpView(const pybind11::object &operationObject);
  PyOperation &getOperation() override { return operation; }

  /// Creates a python subclass of `userClass` for use as a raw op view.
  static pybind11::object createRawSubclass(const pybind11::object &userClass);

  /// Returns the python object of the wrapped operation.
  pybind11::object getOperationObject() { return operationObject; }

  /// Generic builder used by ODS-generated op view classes (`cls`).
  static pybind11::object
  buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList,
               pybind11::list operandList,
               llvm::Optional<pybind11::dict> attributes,
               llvm::Optional<std::vector<PyBlock *>> successors,
               llvm::Optional<int> regions, DefaultingPyLocation location,
               const pybind11::object &maybeIp);

private:
  PyOperation &operation;           // For efficient, cast-free access from C++
  pybind11::object operationObject; // Holds the reference.
};
548 
549 /// Wrapper around an MlirRegion.
550 /// Regions are managed completely by their containing operation. Unlike the
551 /// C++ API, the python API does not support detached regions.
552 class PyRegion {
553 public:
554   PyRegion(PyOperationRef parentOperation, MlirRegion region)
555       : parentOperation(std::move(parentOperation)), region(region) {
556     assert(!mlirRegionIsNull(region) && "python region cannot be null");
557   }
558   operator MlirRegion() const { return region; }
559 
560   MlirRegion get() { return region; }
561   PyOperationRef &getParentOperation() { return parentOperation; }
562 
563   void checkValid() { return parentOperation->checkValid(); }
564 
565 private:
566   PyOperationRef parentOperation;
567   MlirRegion region;
568 };
569 
570 /// Wrapper around an MlirBlock.
571 /// Blocks are managed completely by their containing operation. Unlike the
572 /// C++ API, the python API does not support detached blocks.
573 class PyBlock {
574 public:
575   PyBlock(PyOperationRef parentOperation, MlirBlock block)
576       : parentOperation(std::move(parentOperation)), block(block) {
577     assert(!mlirBlockIsNull(block) && "python block cannot be null");
578   }
579 
580   MlirBlock get() { return block; }
581   PyOperationRef &getParentOperation() { return parentOperation; }
582 
583   void checkValid() { return parentOperation->checkValid(); }
584 
585 private:
586   PyOperationRef parentOperation;
587   MlirBlock block;
588 };
589 
/// An insertion point maintains a pointer to a Block and a reference operation.
/// Calls to insert() will insert a new operation before the
/// reference operation. If the reference operation is null, then appends to
/// the end of the block.
class PyInsertionPoint {
public:
  /// Creates an insertion point positioned after the last operation in the
  /// block, but still inside the block.
  PyInsertionPoint(PyBlock &block);
  /// Creates an insertion point positioned before a reference operation.
  PyInsertionPoint(PyOperationBase &beforeOperationBase);

  /// Shortcut to create an insertion point at the beginning of the block.
  static PyInsertionPoint atBlockBegin(PyBlock &block);
  /// Shortcut to create an insertion point before the block terminator.
  static PyInsertionPoint atBlockTerminator(PyBlock &block);

  /// Inserts an operation.
  void insert(PyOperationBase &operationBase);

  /// Enter and exit the context manager.
  pybind11::object contextEnter();
  void contextExit(const pybind11::object &excType,
                   const pybind11::object &excVal,
                   const pybind11::object &excTb);

  /// Returns the block this insertion point inserts into.
  PyBlock &getBlock() { return block; }

private:
  // Trampoline constructor that avoids null initializing members while
  // looking up parents.
  PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation)
      : refOperation(std::move(refOperation)), block(std::move(block)) {}

  /// Operation to insert before; empty means append at the end of `block`.
  llvm::Optional<PyOperationRef> refOperation;
  PyBlock block;
};
/// Wrapper around the generic MlirType.
/// The lifetime of a type is bound by the PyContext that created it.
class PyType : public BaseContextObject {
public:
  PyType(PyMlirContextRef contextRef, MlirType type)
      : BaseContextObject(std::move(contextRef)), type(type) {}
  /// Compares the wrapped MlirType values.
  bool operator==(const PyType &other);
  /// Accesses the wrapped MlirType (implicitly or explicitly).
  operator MlirType() const { return type; }
  MlirType get() const { return type; }

  /// Gets a capsule wrapping the void* within the MlirType.
  pybind11::object getCapsule();

  /// Creates a PyType from the MlirType wrapped by a capsule.
  /// The result is a new PyType returned by value; it is not a uniqued or
  /// pre-existing python object.
  /// NOTE(review): earlier docs claimed this takes ownership of the MlirType;
  /// types are bound to their context (see above) — confirm ownership
  /// semantics against the implementation.
  static PyType createFromCapsule(pybind11::object capsule);

private:
  MlirType type;
};
649 
/// CRTP base class for Python types that subclass Type and should be
/// castable from it (i.e. via something like IntegerType(t)).
/// By default, type class hierarchies are one level deep (i.e. a
/// concrete type class extends PyType); however, intermediate python-visible
/// base classes can be modeled by specifying a BaseTy.
template <typename DerivedTy, typename BaseTy = PyType>
class PyConcreteType : public BaseTy {
public:
  // Derived classes must define statics for:
  //   IsAFunctionTy isaFunction
  //   const char *pyClassName
  using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
  using IsAFunctionTy = bool (*)(MlirType);

  PyConcreteType() = default;
  PyConcreteType(PyMlirContextRef contextRef, MlirType t)
      : BaseTy(std::move(contextRef), t) {}
  /// Casting constructor; fails via castFrom() if `orig` is not a DerivedTy.
  PyConcreteType(PyType &orig)
      : PyConcreteType(orig.getContext(), castFrom(orig)) {}

  /// Returns `orig` as an MlirType if DerivedTy::isaFunction accepts it;
  /// otherwise throws SetPyError(PyExc_ValueError, ...) with a message naming
  /// the target class and the repr of `orig`.
  static MlirType castFrom(PyType &orig) {
    if (!DerivedTy::isaFunction(orig)) {
      auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
      throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
                                             DerivedTy::pyClassName +
                                             " (from " + origRepr + ")");
    }
    return orig;
  }

  /// Registers DerivedTy in `m` under DerivedTy::pyClassName (module_local),
  /// adds a casting __init__ and a static `isinstance`, then lets the derived
  /// class add its own members via bindDerived().
  static void bind(pybind11::module &m) {
    auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
    // NOTE(review): for an __init__ binding, keep_alive<0, 1> names the (None)
    // return value as the nurse; confirm this policy has any effect.
    cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(),
            pybind11::arg("cast_from_type"));
    cls.def_static(
        "isinstance",
        [](PyType &otherType) -> bool {
          return DerivedTy::isaFunction(otherType);
        },
        pybind11::arg("other"));
    DerivedTy::bindDerived(cls);
  }

  /// Implemented by derived classes to add methods to the Python subclass.
  static void bindDerived(ClassTy &m) {}
};
696 
/// Wrapper around the generic MlirAttribute.
/// The lifetime of an attribute is bound by the PyContext that created it.
class PyAttribute : public BaseContextObject {
public:
  PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
      : BaseContextObject(std::move(contextRef)), attr(attr) {}
  /// Compares the wrapped MlirAttribute values.
  bool operator==(const PyAttribute &other);
  /// Accesses the wrapped MlirAttribute (implicitly or explicitly).
  operator MlirAttribute() const { return attr; }
  MlirAttribute get() const { return attr; }

  /// Gets a capsule wrapping the void* within the MlirAttribute.
  pybind11::object getCapsule();

  /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
  /// The result is a new PyAttribute returned by value; it is not a uniqued or
  /// pre-existing python object.
  /// NOTE(review): earlier docs claimed this takes ownership of the
  /// MlirAttribute; attributes are bound to their context (see above) —
  /// confirm ownership semantics against the implementation.
  static PyAttribute createFromCapsule(pybind11::object capsule);

private:
  MlirAttribute attr;
};
719 
/// Represents a Python MlirNamedAttr, carrying an optional owned name.
/// TODO: Refactor this and the C-API to be based on an Identifier owned
/// by the context so as to avoid ownership issues here.
class PyNamedAttribute {
public:
  /// Constructs a PyNamedAttr that retains an owned name. This should be
  /// used in any code that originates an MlirNamedAttribute from a python
  /// string.
  /// The lifetime of the PyNamedAttr must extend to the lifetime of the
  /// passed attribute.
  PyNamedAttribute(MlirAttribute attr, std::string ownedName);

  /// The C-API named attribute; its name refers into `ownedName`'s storage.
  MlirNamedAttribute namedAttr;

private:
  // Since the MlirNamedAttr contains an internal pointer to the actual
  // memory of the owned string, it must be heap allocated to remain valid.
  // Otherwise, strings that fit within the small object optimization threshold
  // will have their memory address change as the containing object is moved,
  // resulting in an invalid aliased pointer.
  std::unique_ptr<std::string> ownedName;
};
742 
/// CRTP base class for Python attributes that subclass Attribute and should
/// be castable from it (i.e. via something like StringAttr(attr)).
/// By default, attribute class hierarchies are one level deep (i.e. a
/// concrete attribute class extends PyAttribute); however, intermediate
/// python-visible base classes can be modeled by specifying a BaseTy.
template <typename DerivedTy, typename BaseTy = PyAttribute>
class PyConcreteAttribute : public BaseTy {
public:
  // Derived classes must define statics for:
  //   IsAFunctionTy isaFunction
  //   const char *pyClassName
  using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
  using IsAFunctionTy = bool (*)(MlirAttribute);

  PyConcreteAttribute() = default;
  PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
      : BaseTy(std::move(contextRef), attr) {}
  /// Casting constructor; fails via castFrom() if `orig` is not a DerivedTy.
  PyConcreteAttribute(PyAttribute &orig)
      : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}

  /// Returns `orig` as an MlirAttribute if DerivedTy::isaFunction accepts it;
  /// otherwise throws SetPyError(PyExc_ValueError, ...) with a message naming
  /// the target class and the repr of `orig`.
  static MlirAttribute castFrom(PyAttribute &orig) {
    if (!DerivedTy::isaFunction(orig)) {
      auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
      throw SetPyError(PyExc_ValueError,
                       llvm::Twine("Cannot cast attribute to ") +
                           DerivedTy::pyClassName + " (from " + origRepr + ")");
    }
    return orig;
  }

  /// Registers DerivedTy in `m` under DerivedTy::pyClassName (module_local,
  /// with buffer protocol support), adds a casting __init__, a static
  /// `isinstance` and a read-only `type` property, then lets the derived class
  /// add its own members via bindDerived().
  static void bind(pybind11::module &m) {
    auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(),
                       pybind11::module_local());
    // NOTE(review): for an __init__ binding, keep_alive<0, 1> names the (None)
    // return value as the nurse; confirm this policy has any effect.
    cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(),
            pybind11::arg("cast_from_attr"));
    cls.def_static(
        "isinstance",
        [](PyAttribute &otherAttr) -> bool {
          return DerivedTy::isaFunction(otherAttr);
        },
        pybind11::arg("other"));
    // Exposes the attribute's type as a PyType in the same context.
    cls.def_property_readonly("type", [](PyAttribute &attr) {
      return PyType(attr.getContext(), mlirAttributeGetType(attr));
    });
    DerivedTy::bindDerived(cls);
  }

  /// Implemented by derived classes to add methods to the Python subclass.
  static void bindDerived(ClassTy &m) {}
};
793 
794 /// Wrapper around the generic MlirValue.
795 /// Values are managed completely by the operation that resulted in their
796 /// definition. For op result value, this is the operation that defines the
797 /// value. For block argument values, this is the operation that contains the
798 /// block to which the value is an argument (blocks cannot be detached in Python
799 /// bindings so such operation always exists).
800 class PyValue {
801 public:
802   PyValue(PyOperationRef parentOperation, MlirValue value)
803       : parentOperation(parentOperation), value(value) {}
804   operator MlirValue() const { return value; }
805 
806   MlirValue get() { return value; }
807   PyOperationRef &getParentOperation() { return parentOperation; }
808 
809   void checkValid() { return parentOperation->checkValid(); }
810 
811   /// Gets a capsule wrapping the void* within the MlirValue.
812   pybind11::object getCapsule();
813 
814   /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
815   /// the underlying MlirValue is still tied to the owning operation.
816   static PyValue createFromCapsule(pybind11::object capsule);
817 
818 private:
819   PyOperationRef parentOperation;
820   MlirValue value;
821 };
822 
823 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
824 class PyAffineExpr : public BaseContextObject {
825 public:
826   PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
827       : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}
828   bool operator==(const PyAffineExpr &other);
829   operator MlirAffineExpr() const { return affineExpr; }
830   MlirAffineExpr get() const { return affineExpr; }
831 
832   /// Gets a capsule wrapping the void* within the MlirAffineExpr.
833   pybind11::object getCapsule();
834 
835   /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
836   /// Note that PyAffineExpr instances are uniqued, so the returned object
837   /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
838   /// is taken by calling this function.
839   static PyAffineExpr createFromCapsule(pybind11::object capsule);
840 
841   PyAffineExpr add(const PyAffineExpr &other) const;
842   PyAffineExpr mul(const PyAffineExpr &other) const;
843   PyAffineExpr floorDiv(const PyAffineExpr &other) const;
844   PyAffineExpr ceilDiv(const PyAffineExpr &other) const;
845   PyAffineExpr mod(const PyAffineExpr &other) const;
846 
847 private:
848   MlirAffineExpr affineExpr;
849 };
850 
851 class PyAffineMap : public BaseContextObject {
852 public:
853   PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)
854       : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {}
855   bool operator==(const PyAffineMap &other);
856   operator MlirAffineMap() const { return affineMap; }
857   MlirAffineMap get() const { return affineMap; }
858 
859   /// Gets a capsule wrapping the void* within the MlirAffineMap.
860   pybind11::object getCapsule();
861 
862   /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule.
863   /// Note that PyAffineMap instances are uniqued, so the returned object
864   /// may be a pre-existing object. Ownership of the underlying MlirAffineMap
865   /// is taken by calling this function.
866   static PyAffineMap createFromCapsule(pybind11::object capsule);
867 
868 private:
869   MlirAffineMap affineMap;
870 };
871 
872 class PyIntegerSet : public BaseContextObject {
873 public:
874   PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
875       : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
876   bool operator==(const PyIntegerSet &other);
877   operator MlirIntegerSet() const { return integerSet; }
878   MlirIntegerSet get() const { return integerSet; }
879 
880   /// Gets a capsule wrapping the void* within the MlirIntegerSet.
881   pybind11::object getCapsule();
882 
883   /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
884   /// Note that PyIntegerSet instances may be uniqued, so the returned object
885   /// may be a pre-existing object. Integer sets are owned by the context.
886   static PyIntegerSet createFromCapsule(pybind11::object capsule);
887 
888 private:
889   MlirIntegerSet integerSet;
890 };
891 
892 /// Bindings for MLIR symbol tables.
893 class PySymbolTable {
894 public:
895   /// Constructs a symbol table for the given operation.
896   explicit PySymbolTable(PyOperationBase &operation);
897 
898   /// Destroys the symbol table.
899   ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); }
900 
901   /// Returns the symbol (opview) with the given name, throws if there is no
902   /// such symbol in the table.
903   pybind11::object dunderGetItem(const std::string &name);
904 
905   /// Removes the given operation from the symbol table and erases it.
906   void erase(PyOperationBase &symbol);
907 
908   /// Removes the operation with the given name from the symbol table and erases
909   /// it, throws if there is no such symbol in the table.
910   void dunderDel(const std::string &name);
911 
912   /// Inserts the given operation into the symbol table. The operation must have
913   /// the symbol trait.
914   PyAttribute insert(PyOperationBase &symbol);
915 
916   /// Gets and sets the name of a symbol op.
917   static PyAttribute getSymbolName(PyOperationBase &symbol);
918   static void setSymbolName(PyOperationBase &symbol, const std::string &name);
919 
920   /// Gets and sets the visibility of a symbol op.
921   static PyAttribute getVisibility(PyOperationBase &symbol);
922   static void setVisibility(PyOperationBase &symbol,
923                             const std::string &visibility);
924 
925   /// Replaces all symbol uses within an operation. See the API
926   /// mlirSymbolTableReplaceAllSymbolUses for all caveats.
927   static void replaceAllSymbolUses(const std::string &oldSymbol,
928                                    const std::string &newSymbol,
929                                    PyOperationBase &from);
930 
931   /// Walks all symbol tables under and including 'from'.
932   static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible,
933                                pybind11::object callback);
934 
935   /// Casts the bindings class into the C API structure.
936   operator MlirSymbolTable() { return symbolTable; }
937 
938 private:
939   PyOperationRef operation;
940   MlirSymbolTable symbolTable;
941 };
942 
943 void populateIRAffine(pybind11::module &m);
944 void populateIRAttributes(pybind11::module &m);
945 void populateIRCore(pybind11::module &m);
946 void populateIRInterfaces(pybind11::module &m);
947 void populateIRTypes(pybind11::module &m);
948 
949 } // namespace python
950 } // namespace mlir
951 
952 namespace pybind11 {
953 namespace detail {
954 
955 template <>
956 struct type_caster<mlir::python::DefaultingPyMlirContext>
957     : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {};
958 template <>
959 struct type_caster<mlir::python::DefaultingPyLocation>
960     : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};
961 
962 } // namespace detail
963 } // namespace pybind11
964 
965 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
966