1 //===- IRModules.h - IR Submodules of pybind module -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
10 #define MLIR_BINDINGS_PYTHON_IRMODULES_H
11 
12 #include <vector>
13 
14 #include "PybindUtils.h"
15 
16 #include "mlir-c/AffineExpr.h"
17 #include "mlir-c/AffineMap.h"
18 #include "mlir-c/IR.h"
19 #include "mlir-c/IntegerSet.h"
20 #include "llvm/ADT/DenseMap.h"
21 #include "llvm/ADT/Optional.h"
22 
23 namespace mlir {
24 namespace python {
25 
26 class PyBlock;
27 class PyInsertionPoint;
28 class PyLocation;
29 class DefaultingPyLocation;
30 class PyMlirContext;
31 class DefaultingPyMlirContext;
32 class PyModule;
33 class PyOperation;
34 class PyType;
35 class PyValue;
36 
37 /// Template for a reference to a concrete type which captures a python
38 /// reference to its underlying python object.
39 template <typename T>
40 class PyObjectRef {
41 public:
42   PyObjectRef(T *referrent, pybind11::object object)
43       : referrent(referrent), object(std::move(object)) {
44     assert(this->referrent &&
45            "cannot construct PyObjectRef with null referrent");
46     assert(this->object && "cannot construct PyObjectRef with null object");
47   }
48   PyObjectRef(PyObjectRef &&other)
49       : referrent(other.referrent), object(std::move(other.object)) {
50     other.referrent = nullptr;
51     assert(!other.object);
52   }
53   PyObjectRef(const PyObjectRef &other)
54       : referrent(other.referrent), object(other.object /* copies */) {}
55   ~PyObjectRef() {}
56 
57   int getRefCount() {
58     if (!object)
59       return 0;
60     return object.ref_count();
61   }
62 
63   /// Releases the object held by this instance, returning it.
64   /// This is the proper thing to return from a function that wants to return
65   /// the reference. Note that this does not work from initializers.
66   pybind11::object releaseObject() {
67     assert(referrent && object);
68     referrent = nullptr;
69     auto stolen = std::move(object);
70     return stolen;
71   }
72 
73   T *get() { return referrent; }
74   T *operator->() {
75     assert(referrent && object);
76     return referrent;
77   }
78   pybind11::object getObject() {
79     assert(referrent && object);
80     return object;
81   }
82   operator bool() const { return referrent && object; }
83 
84 private:
85   T *referrent;
86   pybind11::object object;
87 };
88 
89 /// Tracks an entry in the thread context stack. New entries are pushed onto
90 /// here for each with block that activates a new InsertionPoint, Context or
91 /// Location.
92 ///
93 /// Pushing either a Location or InsertionPoint also pushes its associated
94 /// Context. Pushing a Context will not modify the Location or InsertionPoint
95 /// unless if they are from a different context, in which case, they are
96 /// cleared.
97 class PyThreadContextEntry {
98 public:
99   enum class FrameKind {
100     Context,
101     InsertionPoint,
102     Location,
103   };
104 
105   PyThreadContextEntry(FrameKind frameKind, pybind11::object context,
106                        pybind11::object insertionPoint,
107                        pybind11::object location)
108       : context(std::move(context)), insertionPoint(std::move(insertionPoint)),
109         location(std::move(location)), frameKind(frameKind) {}
110 
111   /// Gets the top of stack context and return nullptr if not defined.
112   static PyMlirContext *getDefaultContext();
113 
114   /// Gets the top of stack insertion point and return nullptr if not defined.
115   static PyInsertionPoint *getDefaultInsertionPoint();
116 
117   /// Gets the top of stack location and returns nullptr if not defined.
118   static PyLocation *getDefaultLocation();
119 
120   PyMlirContext *getContext();
121   PyInsertionPoint *getInsertionPoint();
122   PyLocation *getLocation();
123   FrameKind getFrameKind() { return frameKind; }
124 
125   /// Stack management.
126   static PyThreadContextEntry *getTopOfStack();
127   static pybind11::object pushContext(PyMlirContext &context);
128   static void popContext(PyMlirContext &context);
129   static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint);
130   static void popInsertionPoint(PyInsertionPoint &insertionPoint);
131   static pybind11::object pushLocation(PyLocation &location);
132   static void popLocation(PyLocation &location);
133 
134   /// Gets the thread local stack.
135   static std::vector<PyThreadContextEntry> &getStack();
136 
137 private:
138   static void push(FrameKind frameKind, pybind11::object context,
139                    pybind11::object insertionPoint, pybind11::object location);
140 
141   /// An object reference to the PyContext.
142   pybind11::object context;
143   /// An object reference to the current insertion point.
144   pybind11::object insertionPoint;
145   /// An object reference to the current location.
146   pybind11::object location;
147   // The kind of push that was performed.
148   FrameKind frameKind;
149 };
150 
151 /// Wrapper around MlirContext.
152 using PyMlirContextRef = PyObjectRef<PyMlirContext>;
153 class PyMlirContext {
154 public:
155   PyMlirContext() = delete;
156   PyMlirContext(const PyMlirContext &) = delete;
157   PyMlirContext(PyMlirContext &&) = delete;
158 
159   /// For the case of a python __init__ (py::init) method, pybind11 is quite
160   /// strict about needing to return a pointer that is not yet associated to
161   /// an py::object. Since the forContext() method acts like a pool, possibly
162   /// returning a recycled context, it does not satisfy this need. The usual
163   /// way in python to accomplish such a thing is to override __new__, but
164   /// that is also not supported by pybind11. Instead, we use this entry
165   /// point which always constructs a fresh context (which cannot alias an
166   /// existing one because it is fresh).
167   static PyMlirContext *createNewContextForInit();
168 
169   /// Returns a context reference for the singleton PyMlirContext wrapper for
170   /// the given context.
171   static PyMlirContextRef forContext(MlirContext context);
172   ~PyMlirContext();
173 
174   /// Accesses the underlying MlirContext.
175   MlirContext get() { return context; }
176 
177   /// Gets a strong reference to this context, which will ensure it is kept
178   /// alive for the life of the reference.
179   PyMlirContextRef getRef() {
180     return PyMlirContextRef(this, pybind11::cast(this));
181   }
182 
183   /// Gets a capsule wrapping the void* within the MlirContext.
184   pybind11::object getCapsule();
185 
186   /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
187   /// Note that PyMlirContext instances are uniqued, so the returned object
188   /// may be a pre-existing object. Ownership of the underlying MlirContext
189   /// is taken by calling this function.
190   static pybind11::object createFromCapsule(pybind11::object capsule);
191 
192   /// Gets the count of live context objects. Used for testing.
193   static size_t getLiveCount();
194 
195   /// Gets the count of live operations associated with this context.
196   /// Used for testing.
197   size_t getLiveOperationCount();
198 
199   /// Gets the count of live modules associated with this context.
200   /// Used for testing.
201   size_t getLiveModuleCount();
202 
203   /// Enter and exit the context manager.
204   pybind11::object contextEnter();
205   void contextExit(pybind11::object excType, pybind11::object excVal,
206                    pybind11::object excTb);
207 
208 private:
209   PyMlirContext(MlirContext context);
210   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
211   // preserving the relationship that an MlirContext maps to a single
212   // PyMlirContext wrapper. This could be replaced in the future with an
213   // extension mechanism on the MlirContext for stashing user pointers.
214   // Note that this holds a handle, which does not imply ownership.
215   // Mappings will be removed when the context is destructed.
216   using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
217   static LiveContextMap &getLiveContexts();
218 
219   // Interns all live modules associated with this context. Modules tracked
220   // in this map are valid. When a module is invalidated, it is removed
221   // from this map, and while it still exists as an instance, any
222   // attempt to access it will raise an error.
223   using LiveModuleMap =
224       llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
225   LiveModuleMap liveModules;
226 
227   // Interns all live operations associated with this context. Operations
228   // tracked in this map are valid. When an operation is invalidated, it is
229   // removed from this map, and while it still exists as an instance, any
230   // attempt to access it will raise an error.
231   using LiveOperationMap =
232       llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
233   LiveOperationMap liveOperations;
234 
235   MlirContext context;
236   friend class PyModule;
237   friend class PyOperation;
238 };
239 
240 /// Used in function arguments when None should resolve to the current context
241 /// manager set instance.
242 class DefaultingPyMlirContext
243     : public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
244 public:
245   using Defaulting::Defaulting;
246   static constexpr const char kTypeDescription[] =
247       "[ThreadContextAware] mlir.ir.Context";
248   static PyMlirContext &resolve();
249 };
250 
251 /// Base class for all objects that directly or indirectly depend on an
252 /// MlirContext. The lifetime of the context will extend at least to the
253 /// lifetime of these instances.
254 /// Immutable objects that depend on a context extend this directly.
255 class BaseContextObject {
256 public:
257   BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
258     assert(this->contextRef &&
259            "context object constructed with null context ref");
260   }
261 
262   /// Accesses the context reference.
263   PyMlirContextRef &getContext() { return contextRef; }
264 
265 private:
266   PyMlirContextRef contextRef;
267 };
268 
269 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
270 /// order to differentiate it from the `Dialect` base class which is extended by
271 /// plugins which extend dialect functionality through extension python code.
272 /// This should be seen as the "low-level" object and `Dialect` as the
273 /// high-level, user facing object.
274 class PyDialectDescriptor : public BaseContextObject {
275 public:
276   PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
277       : BaseContextObject(std::move(contextRef)), dialect(dialect) {}
278 
279   MlirDialect get() { return dialect; }
280 
281 private:
282   MlirDialect dialect;
283 };
284 
285 /// User-level object for accessing dialects with dotted syntax such as:
286 ///   ctx.dialect.std
287 class PyDialects : public BaseContextObject {
288 public:
289   PyDialects(PyMlirContextRef contextRef)
290       : BaseContextObject(std::move(contextRef)) {}
291 
292   MlirDialect getDialectForKey(const std::string &key, bool attrError);
293 };
294 
295 /// User-level dialect object. For dialects that have a registered extension,
296 /// this will be the base class of the extension dialect type. For un-extended,
297 /// objects of this type will be returned directly.
298 class PyDialect {
299 public:
300   PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}
301 
302   pybind11::object getDescriptor() { return descriptor; }
303 
304 private:
305   pybind11::object descriptor;
306 };
307 
308 /// Wrapper around an MlirLocation.
309 class PyLocation : public BaseContextObject {
310 public:
311   PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
312       : BaseContextObject(std::move(contextRef)), loc(loc) {}
313 
314   operator MlirLocation() const { return loc; }
315   MlirLocation get() const { return loc; }
316 
317   /// Enter and exit the context manager.
318   pybind11::object contextEnter();
319   void contextExit(pybind11::object excType, pybind11::object excVal,
320                    pybind11::object excTb);
321 
322   /// Gets a capsule wrapping the void* within the MlirLocation.
323   pybind11::object getCapsule();
324 
325   /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
326   /// Note that PyLocation instances are uniqued, so the returned object
327   /// may be a pre-existing object. Ownership of the underlying MlirLocation
328   /// is taken by calling this function.
329   static PyLocation createFromCapsule(pybind11::object capsule);
330 
331 private:
332   MlirLocation loc;
333 };
334 
335 /// Used in function arguments when None should resolve to the current context
336 /// manager set instance.
337 class DefaultingPyLocation
338     : public Defaulting<DefaultingPyLocation, PyLocation> {
339 public:
340   using Defaulting::Defaulting;
341   static constexpr const char kTypeDescription[] =
342       "[ThreadContextAware] mlir.ir.Location";
343   static PyLocation &resolve();
344 
345   operator MlirLocation() const { return *get(); }
346 };
347 
348 /// Wrapper around MlirModule.
349 /// This is the top-level, user-owned object that contains regions/ops/blocks.
350 class PyModule;
351 using PyModuleRef = PyObjectRef<PyModule>;
352 class PyModule : public BaseContextObject {
353 public:
354   /// Returns a PyModule reference for the given MlirModule. This may return
355   /// a pre-existing or new object.
356   static PyModuleRef forModule(MlirModule module);
357   PyModule(PyModule &) = delete;
358   PyModule(PyMlirContext &&) = delete;
359   ~PyModule();
360 
361   /// Gets the backing MlirModule.
362   MlirModule get() { return module; }
363 
364   /// Gets a strong reference to this module.
365   PyModuleRef getRef() {
366     return PyModuleRef(this,
367                        pybind11::reinterpret_borrow<pybind11::object>(handle));
368   }
369 
370   /// Gets a capsule wrapping the void* within the MlirModule.
371   /// Note that the module does not (yet) provide a corresponding factory for
372   /// constructing from a capsule as that would require uniquing PyModule
373   /// instances, which is not currently done.
374   pybind11::object getCapsule();
375 
376   /// Creates a PyModule from the MlirModule wrapped by a capsule.
377   /// Note that PyModule instances are uniqued, so the returned object
378   /// may be a pre-existing object. Ownership of the underlying MlirModule
379   /// is taken by calling this function.
380   static pybind11::object createFromCapsule(pybind11::object capsule);
381 
382 private:
383   PyModule(PyMlirContextRef contextRef, MlirModule module);
384   MlirModule module;
385   pybind11::handle handle;
386 };
387 
388 /// Base class for PyOperation and PyOpView which exposes the primary, user
389 /// visible methods for manipulating it.
390 class PyOperationBase {
391 public:
392   virtual ~PyOperationBase() = default;
393   /// Implements the bound 'print' method and helps with others.
394   void print(pybind11::object fileObject, bool binary,
395              llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
396              bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope);
397   pybind11::object getAsm(bool binary,
398                           llvm::Optional<int64_t> largeElementsLimit,
399                           bool enableDebugInfo, bool prettyDebugInfo,
400                           bool printGenericOpForm, bool useLocalScope);
401 
402   /// Moves the operation before or after the other operation.
403   void moveAfter(PyOperationBase &other);
404   void moveBefore(PyOperationBase &other);
405 
406   /// Each must provide access to the raw Operation.
407   virtual PyOperation &getOperation() = 0;
408 };
409 
410 /// Wrapper around PyOperation.
411 /// Operations exist in either an attached (dependent) or detached (top-level)
412 /// state. In the detached state (as on creation), an operation is owned by
413 /// the creator and its lifetime extends either until its reference count
414 /// drops to zero or it is attached to a parent, at which point its lifetime
415 /// is bounded by its top-level parent reference.
416 class PyOperation;
417 using PyOperationRef = PyObjectRef<PyOperation>;
418 class PyOperation : public PyOperationBase, public BaseContextObject {
419 public:
420   ~PyOperation();
421   PyOperation &getOperation() override { return *this; }
422 
423   /// Returns a PyOperation for the given MlirOperation, optionally associating
424   /// it with a parentKeepAlive.
425   static PyOperationRef
426   forOperation(PyMlirContextRef contextRef, MlirOperation operation,
427                pybind11::object parentKeepAlive = pybind11::object());
428 
429   /// Creates a detached operation. The operation must not be associated with
430   /// any existing live operation.
431   static PyOperationRef
432   createDetached(PyMlirContextRef contextRef, MlirOperation operation,
433                  pybind11::object parentKeepAlive = pybind11::object());
434 
435   /// Detaches the operation from its parent block and updates its state
436   /// accordingly.
437   void detachFromParent() {
438     mlirOperationRemoveFromParent(getOperation());
439     setDetached();
440     parentKeepAlive = pybind11::object();
441   }
442 
443   /// Gets the backing operation.
444   operator MlirOperation() const { return get(); }
445   MlirOperation get() const {
446     checkValid();
447     return operation;
448   }
449 
450   PyOperationRef getRef() {
451     return PyOperationRef(
452         this, pybind11::reinterpret_borrow<pybind11::object>(handle));
453   }
454 
455   bool isAttached() { return attached; }
456   void setAttached(pybind11::object parent = pybind11::object()) {
457     assert(!attached && "operation already attached");
458     attached = true;
459   }
460   void setDetached() {
461     assert(attached && "operation already detached");
462     attached = false;
463   }
464   void checkValid() const;
465 
466   /// Gets the owning block or raises an exception if the operation has no
467   /// owning block.
468   PyBlock getBlock();
469 
470   /// Gets the parent operation or raises an exception if the operation has
471   /// no parent.
472   llvm::Optional<PyOperationRef> getParentOperation();
473 
474   /// Gets a capsule wrapping the void* within the MlirOperation.
475   pybind11::object getCapsule();
476 
477   /// Creates a PyOperation from the MlirOperation wrapped by a capsule.
478   /// Ownership of the underlying MlirOperation is taken by calling this
479   /// function.
480   static pybind11::object createFromCapsule(pybind11::object capsule);
481 
482   /// Creates an operation. See corresponding python docstring.
483   static pybind11::object
484   create(std::string name, llvm::Optional<std::vector<PyType *>> results,
485          llvm::Optional<std::vector<PyValue *>> operands,
486          llvm::Optional<pybind11::dict> attributes,
487          llvm::Optional<std::vector<PyBlock *>> successors, int regions,
488          DefaultingPyLocation location, pybind11::object ip);
489 
490   /// Creates an OpView suitable for this operation.
491   pybind11::object createOpView();
492 
493   /// Erases the underlying MlirOperation, removes its pointer from the
494   /// parent context's live operations map, and sets the valid bit false.
495   void erase();
496 
497 private:
498   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
499   static PyOperationRef createInstance(PyMlirContextRef contextRef,
500                                        MlirOperation operation,
501                                        pybind11::object parentKeepAlive);
502 
503   MlirOperation operation;
504   pybind11::handle handle;
505   // Keeps the parent alive, regardless of whether it is an Operation or
506   // Module.
507   // TODO: As implemented, this facility is only sufficient for modeling the
508   // trivial module parent back-reference. Generalize this to also account for
509   // transitions from detached to attached and address TODOs in the
510   // ir_operation.py regarding testing corresponding lifetime guarantees.
511   pybind11::object parentKeepAlive;
512   bool attached = true;
513   bool valid = true;
514 
515   friend class PyOperationBase;
516 };
517 
518 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
519 /// providing more instance-specific accessors and serve as the base class for
520 /// custom ODS-style operation classes. Since this class is subclass on the
521 /// python side, it must present an __init__ method that operates in pure
522 /// python types.
523 class PyOpView : public PyOperationBase {
524 public:
525   PyOpView(pybind11::object operationObject);
526   PyOperation &getOperation() override { return operation; }
527 
528   static pybind11::object createRawSubclass(pybind11::object userClass);
529 
530   pybind11::object getOperationObject() { return operationObject; }
531 
532   static pybind11::object
533   buildGeneric(pybind11::object cls, pybind11::list resultTypeList,
534                pybind11::list operandList,
535                llvm::Optional<pybind11::dict> attributes,
536                llvm::Optional<std::vector<PyBlock *>> successors,
537                llvm::Optional<int> regions, DefaultingPyLocation location,
538                pybind11::object maybeIp);
539 
540 private:
541   PyOperation &operation;           // For efficient, cast-free access from C++
542   pybind11::object operationObject; // Holds the reference.
543 };
544 
545 /// Wrapper around an MlirRegion.
546 /// Regions are managed completely by their containing operation. Unlike the
547 /// C++ API, the python API does not support detached regions.
548 class PyRegion {
549 public:
550   PyRegion(PyOperationRef parentOperation, MlirRegion region)
551       : parentOperation(std::move(parentOperation)), region(region) {
552     assert(!mlirRegionIsNull(region) && "python region cannot be null");
553   }
554   operator MlirRegion() const { return region; }
555 
556   MlirRegion get() { return region; }
557   PyOperationRef &getParentOperation() { return parentOperation; }
558 
559   void checkValid() { return parentOperation->checkValid(); }
560 
561 private:
562   PyOperationRef parentOperation;
563   MlirRegion region;
564 };
565 
566 /// Wrapper around an MlirBlock.
567 /// Blocks are managed completely by their containing operation. Unlike the
568 /// C++ API, the python API does not support detached blocks.
569 class PyBlock {
570 public:
571   PyBlock(PyOperationRef parentOperation, MlirBlock block)
572       : parentOperation(std::move(parentOperation)), block(block) {
573     assert(!mlirBlockIsNull(block) && "python block cannot be null");
574   }
575 
576   MlirBlock get() { return block; }
577   PyOperationRef &getParentOperation() { return parentOperation; }
578 
579   void checkValid() { return parentOperation->checkValid(); }
580 
581 private:
582   PyOperationRef parentOperation;
583   MlirBlock block;
584 };
585 
586 /// An insertion point maintains a pointer to a Block and a reference operation.
587 /// Calls to insert() will insert a new operation before the
588 /// reference operation. If the reference operation is null, then appends to
589 /// the end of the block.
590 class PyInsertionPoint {
591 public:
592   /// Creates an insertion point positioned after the last operation in the
593   /// block, but still inside the block.
594   PyInsertionPoint(PyBlock &block);
595   /// Creates an insertion point positioned before a reference operation.
596   PyInsertionPoint(PyOperationBase &beforeOperationBase);
597 
598   /// Shortcut to create an insertion point at the beginning of the block.
599   static PyInsertionPoint atBlockBegin(PyBlock &block);
600   /// Shortcut to create an insertion point before the block terminator.
601   static PyInsertionPoint atBlockTerminator(PyBlock &block);
602 
603   /// Inserts an operation.
604   void insert(PyOperationBase &operationBase);
605 
606   /// Enter and exit the context manager.
607   pybind11::object contextEnter();
608   void contextExit(pybind11::object excType, pybind11::object excVal,
609                    pybind11::object excTb);
610 
611   PyBlock &getBlock() { return block; }
612 
613 private:
614   // Trampoline constructor that avoids null initializing members while
615   // looking up parents.
616   PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation)
617       : refOperation(std::move(refOperation)), block(std::move(block)) {}
618 
619   llvm::Optional<PyOperationRef> refOperation;
620   PyBlock block;
621 };
622 /// Wrapper around the generic MlirType.
623 /// The lifetime of a type is bound by the PyContext that created it.
624 class PyType : public BaseContextObject {
625 public:
626   PyType(PyMlirContextRef contextRef, MlirType type)
627       : BaseContextObject(std::move(contextRef)), type(type) {}
628   bool operator==(const PyType &other);
629   operator MlirType() const { return type; }
630   MlirType get() const { return type; }
631 
632   /// Gets a capsule wrapping the void* within the MlirType.
633   pybind11::object getCapsule();
634 
635   /// Creates a PyType from the MlirType wrapped by a capsule.
636   /// Note that PyType instances are uniqued, so the returned object
637   /// may be a pre-existing object. Ownership of the underlying MlirType
638   /// is taken by calling this function.
639   static PyType createFromCapsule(pybind11::object capsule);
640 
641 private:
642   MlirType type;
643 };
644 
645 /// CRTP base classes for Python types that subclass Type and should be
646 /// castable from it (i.e. via something like IntegerType(t)).
647 /// By default, type class hierarchies are one level deep (i.e. a
648 /// concrete type class extends PyType); however, intermediate python-visible
649 /// base classes can be modeled by specifying a BaseTy.
650 template <typename DerivedTy, typename BaseTy = PyType>
651 class PyConcreteType : public BaseTy {
652 public:
653   // Derived classes must define statics for:
654   //   IsAFunctionTy isaFunction
655   //   const char *pyClassName
656   using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
657   using IsAFunctionTy = bool (*)(MlirType);
658 
659   PyConcreteType() = default;
660   PyConcreteType(PyMlirContextRef contextRef, MlirType t)
661       : BaseTy(std::move(contextRef), t) {}
662   PyConcreteType(PyType &orig)
663       : PyConcreteType(orig.getContext(), castFrom(orig)) {}
664 
665   static MlirType castFrom(PyType &orig) {
666     if (!DerivedTy::isaFunction(orig)) {
667       auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
668       throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
669                                              DerivedTy::pyClassName +
670                                              " (from " + origRepr + ")");
671     }
672     return orig;
673   }
674 
675   static void bind(pybind11::module &m) {
676     auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
677     cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>());
678     cls.def_static("isinstance", [](PyType &otherType) -> bool {
679       return DerivedTy::isaFunction(otherType);
680     });
681     DerivedTy::bindDerived(cls);
682   }
683 
684   /// Implemented by derived classes to add methods to the Python subclass.
685   static void bindDerived(ClassTy &m) {}
686 };
687 
688 /// Wrapper around the generic MlirAttribute.
689 /// The lifetime of a type is bound by the PyContext that created it.
690 class PyAttribute : public BaseContextObject {
691 public:
692   PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
693       : BaseContextObject(std::move(contextRef)), attr(attr) {}
694   bool operator==(const PyAttribute &other);
695   operator MlirAttribute() const { return attr; }
696   MlirAttribute get() const { return attr; }
697 
698   /// Gets a capsule wrapping the void* within the MlirAttribute.
699   pybind11::object getCapsule();
700 
701   /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
702   /// Note that PyAttribute instances are uniqued, so the returned object
703   /// may be a pre-existing object. Ownership of the underlying MlirAttribute
704   /// is taken by calling this function.
705   static PyAttribute createFromCapsule(pybind11::object capsule);
706 
707 private:
708   MlirAttribute attr;
709 };
710 
711 /// Represents a Python MlirNamedAttr, carrying an optional owned name.
712 /// TODO: Refactor this and the C-API to be based on an Identifier owned
713 /// by the context so as to avoid ownership issues here.
714 class PyNamedAttribute {
715 public:
716   /// Constructs a PyNamedAttr that retains an owned name. This should be
717   /// used in any code that originates an MlirNamedAttribute from a python
718   /// string.
719   /// The lifetime of the PyNamedAttr must extend to the lifetime of the
720   /// passed attribute.
721   PyNamedAttribute(MlirAttribute attr, std::string ownedName);
722 
723   MlirNamedAttribute namedAttr;
724 
725 private:
726   // Since the MlirNamedAttr contains an internal pointer to the actual
727   // memory of the owned string, it must be heap allocated to remain valid.
728   // Otherwise, strings that fit within the small object optimization threshold
729   // will have their memory address change as the containing object is moved,
730   // resulting in an invalid aliased pointer.
731   std::unique_ptr<std::string> ownedName;
732 };
733 
734 /// CRTP base classes for Python attributes that subclass Attribute and should
735 /// be castable from it (i.e. via something like StringAttr(attr)).
736 /// By default, attribute class hierarchies are one level deep (i.e. a
737 /// concrete attribute class extends PyAttribute); however, intermediate
738 /// python-visible base classes can be modeled by specifying a BaseTy.
739 template <typename DerivedTy, typename BaseTy = PyAttribute>
740 class PyConcreteAttribute : public BaseTy {
741 public:
742   // Derived classes must define statics for:
743   //   IsAFunctionTy isaFunction
744   //   const char *pyClassName
745   using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
746   using IsAFunctionTy = bool (*)(MlirAttribute);
747 
748   PyConcreteAttribute() = default;
749   PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
750       : BaseTy(std::move(contextRef), attr) {}
751   PyConcreteAttribute(PyAttribute &orig)
752       : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
753 
754   static MlirAttribute castFrom(PyAttribute &orig) {
755     if (!DerivedTy::isaFunction(orig)) {
756       auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
757       throw SetPyError(PyExc_ValueError,
758                        llvm::Twine("Cannot cast attribute to ") +
759                            DerivedTy::pyClassName + " (from " + origRepr + ")");
760     }
761     return orig;
762   }
763 
764   static void bind(pybind11::module &m) {
765     auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(),
766                        pybind11::module_local());
767     cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>());
768     cls.def_static("isinstance", [](PyAttribute &otherAttr) -> bool {
769       return DerivedTy::isaFunction(otherAttr);
770     });
771     cls.def_property_readonly("type", [](PyAttribute &attr) {
772       return PyType(attr.getContext(), mlirAttributeGetType(attr));
773     });
774     DerivedTy::bindDerived(cls);
775   }
776 
777   /// Implemented by derived classes to add methods to the Python subclass.
778   static void bindDerived(ClassTy &m) {}
779 };
780 
781 /// Wrapper around the generic MlirValue.
782 /// Values are managed completely by the operation that resulted in their
783 /// definition. For op result value, this is the operation that defines the
784 /// value. For block argument values, this is the operation that contains the
785 /// block to which the value is an argument (blocks cannot be detached in Python
786 /// bindings so such operation always exists).
787 class PyValue {
788 public:
789   PyValue(PyOperationRef parentOperation, MlirValue value)
790       : parentOperation(parentOperation), value(value) {}
791   operator MlirValue() const { return value; }
792 
793   MlirValue get() { return value; }
794   PyOperationRef &getParentOperation() { return parentOperation; }
795 
796   void checkValid() { return parentOperation->checkValid(); }
797 
798   /// Gets a capsule wrapping the void* within the MlirValue.
799   pybind11::object getCapsule();
800 
801   /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
802   /// the underlying MlirValue is still tied to the owning operation.
803   static PyValue createFromCapsule(pybind11::object capsule);
804 
805 private:
806   PyOperationRef parentOperation;
807   MlirValue value;
808 };
809 
810 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
811 class PyAffineExpr : public BaseContextObject {
812 public:
813   PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
814       : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}
815   bool operator==(const PyAffineExpr &other);
816   operator MlirAffineExpr() const { return affineExpr; }
817   MlirAffineExpr get() const { return affineExpr; }
818 
819   /// Gets a capsule wrapping the void* within the MlirAffineExpr.
820   pybind11::object getCapsule();
821 
822   /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
823   /// Note that PyAffineExpr instances are uniqued, so the returned object
824   /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
825   /// is taken by calling this function.
826   static PyAffineExpr createFromCapsule(pybind11::object capsule);
827 
828   PyAffineExpr add(const PyAffineExpr &other) const;
829   PyAffineExpr mul(const PyAffineExpr &other) const;
830   PyAffineExpr floorDiv(const PyAffineExpr &other) const;
831   PyAffineExpr ceilDiv(const PyAffineExpr &other) const;
832   PyAffineExpr mod(const PyAffineExpr &other) const;
833 
834 private:
835   MlirAffineExpr affineExpr;
836 };
837 
838 class PyAffineMap : public BaseContextObject {
839 public:
840   PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)
841       : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {}
842   bool operator==(const PyAffineMap &other);
843   operator MlirAffineMap() const { return affineMap; }
844   MlirAffineMap get() const { return affineMap; }
845 
846   /// Gets a capsule wrapping the void* within the MlirAffineMap.
847   pybind11::object getCapsule();
848 
849   /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule.
850   /// Note that PyAffineMap instances are uniqued, so the returned object
851   /// may be a pre-existing object. Ownership of the underlying MlirAffineMap
852   /// is taken by calling this function.
853   static PyAffineMap createFromCapsule(pybind11::object capsule);
854 
855 private:
856   MlirAffineMap affineMap;
857 };
858 
859 class PyIntegerSet : public BaseContextObject {
860 public:
861   PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
862       : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
863   bool operator==(const PyIntegerSet &other);
864   operator MlirIntegerSet() const { return integerSet; }
865   MlirIntegerSet get() const { return integerSet; }
866 
867   /// Gets a capsule wrapping the void* within the MlirIntegerSet.
868   pybind11::object getCapsule();
869 
870   /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
871   /// Note that PyIntegerSet instances may be uniqued, so the returned object
872   /// may be a pre-existing object. Integer sets are owned by the context.
873   static PyIntegerSet createFromCapsule(pybind11::object capsule);
874 
875 private:
876   MlirIntegerSet integerSet;
877 };
878 
879 void populateIRAffine(pybind11::module &m);
880 void populateIRAttributes(pybind11::module &m);
881 void populateIRCore(pybind11::module &m);
882 void populateIRInterfaces(pybind11::module &m);
883 void populateIRTypes(pybind11::module &m);
884 
885 } // namespace python
886 } // namespace mlir
887 
888 namespace pybind11 {
889 namespace detail {
890 
891 template <>
892 struct type_caster<mlir::python::DefaultingPyMlirContext>
893     : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {};
894 template <>
895 struct type_caster<mlir::python::DefaultingPyLocation>
896     : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};
897 
898 } // namespace detail
899 } // namespace pybind11
900 
901 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
902