1 //===- IRModules.h - IR Submodules of pybind module -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H
10 #define MLIR_BINDINGS_PYTHON_IRMODULES_H
11 
12 #include <utility>
13 #include <vector>
14 
15 #include "PybindUtils.h"
16 
17 #include "mlir-c/AffineExpr.h"
18 #include "mlir-c/AffineMap.h"
19 #include "mlir-c/Diagnostics.h"
20 #include "mlir-c/IR.h"
21 #include "mlir-c/IntegerSet.h"
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/Optional.h"
24 
25 namespace mlir {
26 namespace python {
27 
28 class PyBlock;
29 class PyDiagnostic;
30 class PyDiagnosticHandler;
31 class PyInsertionPoint;
32 class PyLocation;
33 class DefaultingPyLocation;
34 class PyMlirContext;
35 class DefaultingPyMlirContext;
36 class PyModule;
37 class PyOperation;
38 class PyType;
39 class PySymbolTable;
40 class PyValue;
41 
42 /// Template for a reference to a concrete type which captures a python
43 /// reference to its underlying python object.
44 template <typename T>
45 class PyObjectRef {
46 public:
47   PyObjectRef(T *referrent, pybind11::object object)
48       : referrent(referrent), object(std::move(object)) {
49     assert(this->referrent &&
50            "cannot construct PyObjectRef with null referrent");
51     assert(this->object && "cannot construct PyObjectRef with null object");
52   }
53   PyObjectRef(PyObjectRef &&other)
54       : referrent(other.referrent), object(std::move(other.object)) {
55     other.referrent = nullptr;
56     assert(!other.object);
57   }
58   PyObjectRef(const PyObjectRef &other)
59       : referrent(other.referrent), object(other.object /* copies */) {}
60   ~PyObjectRef() = default;
61 
62   int getRefCount() {
63     if (!object)
64       return 0;
65     return object.ref_count();
66   }
67 
68   /// Releases the object held by this instance, returning it.
69   /// This is the proper thing to return from a function that wants to return
70   /// the reference. Note that this does not work from initializers.
71   pybind11::object releaseObject() {
72     assert(referrent && object);
73     referrent = nullptr;
74     auto stolen = std::move(object);
75     return stolen;
76   }
77 
78   T *get() { return referrent; }
79   T *operator->() {
80     assert(referrent && object);
81     return referrent;
82   }
83   pybind11::object getObject() {
84     assert(referrent && object);
85     return object;
86   }
87   operator bool() const { return referrent && object; }
88 
89 private:
90   T *referrent;
91   pybind11::object object;
92 };
93 
94 /// Tracks an entry in the thread context stack. New entries are pushed onto
95 /// here for each with block that activates a new InsertionPoint, Context or
96 /// Location.
97 ///
98 /// Pushing either a Location or InsertionPoint also pushes its associated
99 /// Context. Pushing a Context will not modify the Location or InsertionPoint
100 /// unless if they are from a different context, in which case, they are
101 /// cleared.
102 class PyThreadContextEntry {
103 public:
104   enum class FrameKind {
105     Context,
106     InsertionPoint,
107     Location,
108   };
109 
110   PyThreadContextEntry(FrameKind frameKind, pybind11::object context,
111                        pybind11::object insertionPoint,
112                        pybind11::object location)
113       : context(std::move(context)), insertionPoint(std::move(insertionPoint)),
114         location(std::move(location)), frameKind(frameKind) {}
115 
116   /// Gets the top of stack context and return nullptr if not defined.
117   static PyMlirContext *getDefaultContext();
118 
119   /// Gets the top of stack insertion point and return nullptr if not defined.
120   static PyInsertionPoint *getDefaultInsertionPoint();
121 
122   /// Gets the top of stack location and returns nullptr if not defined.
123   static PyLocation *getDefaultLocation();
124 
125   PyMlirContext *getContext();
126   PyInsertionPoint *getInsertionPoint();
127   PyLocation *getLocation();
128   FrameKind getFrameKind() { return frameKind; }
129 
130   /// Stack management.
131   static PyThreadContextEntry *getTopOfStack();
132   static pybind11::object pushContext(PyMlirContext &context);
133   static void popContext(PyMlirContext &context);
134   static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint);
135   static void popInsertionPoint(PyInsertionPoint &insertionPoint);
136   static pybind11::object pushLocation(PyLocation &location);
137   static void popLocation(PyLocation &location);
138 
139   /// Gets the thread local stack.
140   static std::vector<PyThreadContextEntry> &getStack();
141 
142 private:
143   static void push(FrameKind frameKind, pybind11::object context,
144                    pybind11::object insertionPoint, pybind11::object location);
145 
146   /// An object reference to the PyContext.
147   pybind11::object context;
148   /// An object reference to the current insertion point.
149   pybind11::object insertionPoint;
150   /// An object reference to the current location.
151   pybind11::object location;
152   // The kind of push that was performed.
153   FrameKind frameKind;
154 };
155 
156 /// Wrapper around MlirContext.
157 using PyMlirContextRef = PyObjectRef<PyMlirContext>;
158 class PyMlirContext {
159 public:
160   PyMlirContext() = delete;
161   PyMlirContext(const PyMlirContext &) = delete;
162   PyMlirContext(PyMlirContext &&) = delete;
163 
164   /// For the case of a python __init__ (py::init) method, pybind11 is quite
165   /// strict about needing to return a pointer that is not yet associated to
166   /// an py::object. Since the forContext() method acts like a pool, possibly
167   /// returning a recycled context, it does not satisfy this need. The usual
168   /// way in python to accomplish such a thing is to override __new__, but
169   /// that is also not supported by pybind11. Instead, we use this entry
170   /// point which always constructs a fresh context (which cannot alias an
171   /// existing one because it is fresh).
172   static PyMlirContext *createNewContextForInit();
173 
174   /// Returns a context reference for the singleton PyMlirContext wrapper for
175   /// the given context.
176   static PyMlirContextRef forContext(MlirContext context);
177   ~PyMlirContext();
178 
179   /// Accesses the underlying MlirContext.
180   MlirContext get() { return context; }
181 
182   /// Gets a strong reference to this context, which will ensure it is kept
183   /// alive for the life of the reference.
184   PyMlirContextRef getRef() {
185     return PyMlirContextRef(this, pybind11::cast(this));
186   }
187 
188   /// Gets a capsule wrapping the void* within the MlirContext.
189   pybind11::object getCapsule();
190 
191   /// Creates a PyMlirContext from the MlirContext wrapped by a capsule.
192   /// Note that PyMlirContext instances are uniqued, so the returned object
193   /// may be a pre-existing object. Ownership of the underlying MlirContext
194   /// is taken by calling this function.
195   static pybind11::object createFromCapsule(pybind11::object capsule);
196 
197   /// Gets the count of live context objects. Used for testing.
198   static size_t getLiveCount();
199 
200   /// Gets the count of live operations associated with this context.
201   /// Used for testing.
202   size_t getLiveOperationCount();
203 
204   /// Gets the count of live modules associated with this context.
205   /// Used for testing.
206   size_t getLiveModuleCount();
207 
208   /// Enter and exit the context manager.
209   pybind11::object contextEnter();
210   void contextExit(const pybind11::object &excType,
211                    const pybind11::object &excVal,
212                    const pybind11::object &excTb);
213 
214   /// Attaches a Python callback as a diagnostic handler, returning a
215   /// registration object (internally a PyDiagnosticHandler).
216   pybind11::object attachDiagnosticHandler(pybind11::object callback);
217 
218 private:
219   PyMlirContext(MlirContext context);
220   // Interns the mapping of live MlirContext::ptr to PyMlirContext instances,
221   // preserving the relationship that an MlirContext maps to a single
222   // PyMlirContext wrapper. This could be replaced in the future with an
223   // extension mechanism on the MlirContext for stashing user pointers.
224   // Note that this holds a handle, which does not imply ownership.
225   // Mappings will be removed when the context is destructed.
226   using LiveContextMap = llvm::DenseMap<void *, PyMlirContext *>;
227   static LiveContextMap &getLiveContexts();
228 
229   // Interns all live modules associated with this context. Modules tracked
230   // in this map are valid. When a module is invalidated, it is removed
231   // from this map, and while it still exists as an instance, any
232   // attempt to access it will raise an error.
233   using LiveModuleMap =
234       llvm::DenseMap<const void *, std::pair<pybind11::handle, PyModule *>>;
235   LiveModuleMap liveModules;
236 
237   // Interns all live operations associated with this context. Operations
238   // tracked in this map are valid. When an operation is invalidated, it is
239   // removed from this map, and while it still exists as an instance, any
240   // attempt to access it will raise an error.
241   using LiveOperationMap =
242       llvm::DenseMap<void *, std::pair<pybind11::handle, PyOperation *>>;
243   LiveOperationMap liveOperations;
244 
245   MlirContext context;
246   friend class PyModule;
247   friend class PyOperation;
248 };
249 
250 /// Used in function arguments when None should resolve to the current context
251 /// manager set instance.
252 class DefaultingPyMlirContext
253     : public Defaulting<DefaultingPyMlirContext, PyMlirContext> {
254 public:
255   using Defaulting::Defaulting;
256   static constexpr const char kTypeDescription[] = "mlir.ir.Context";
257   static PyMlirContext &resolve();
258 };
259 
260 /// Base class for all objects that directly or indirectly depend on an
261 /// MlirContext. The lifetime of the context will extend at least to the
262 /// lifetime of these instances.
263 /// Immutable objects that depend on a context extend this directly.
264 class BaseContextObject {
265 public:
266   BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) {
267     assert(this->contextRef &&
268            "context object constructed with null context ref");
269   }
270 
271   /// Accesses the context reference.
272   PyMlirContextRef &getContext() { return contextRef; }
273 
274 private:
275   PyMlirContextRef contextRef;
276 };
277 
278 /// Python class mirroring the C MlirDiagnostic struct. Note that these structs
279 /// are only valid for the duration of a diagnostic callback and attempting
280 /// to access them outside of that will raise an exception. This applies to
281 /// nested diagnostics (in the notes) as well.
282 class PyDiagnostic {
283 public:
284   PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {}
285   void invalidate();
286   bool isValid() { return valid; }
287   MlirDiagnosticSeverity getSeverity();
288   PyLocation getLocation();
289   pybind11::str getMessage();
290   pybind11::tuple getNotes();
291 
292 private:
293   MlirDiagnostic diagnostic;
294 
295   void checkValid();
296   /// If notes have been materialized from the diagnostic, then this will
297   /// be populated with the corresponding objects (all castable to
298   /// PyDiagnostic).
299   llvm::Optional<pybind11::tuple> materializedNotes;
300   bool valid = true;
301 };
302 
303 /// Represents a diagnostic handler attached to the context. The handler's
304 /// callback will be invoked with PyDiagnostic instances until the detach()
305 /// method is called or the context is destroyed. A diagnostic handler can be
306 /// the subject of a `with` block, which will detach it when the block exits.
307 ///
308 /// Since diagnostic handlers can call back into Python code which can do
309 /// unsafe things (i.e. recursively emitting diagnostics, raising exceptions,
310 /// etc), this is generally not deemed to be a great user-level API. Users
311 /// should generally use some form of DiagnosticCollector. If the handler raises
312 /// any exceptions, they will just be emitted to stderr and dropped.
313 ///
314 /// The unique usage of this class means that its lifetime management is
315 /// different from most other parts of the API. Instances are always created
316 /// in an attached state and can transition to a detached state by either:
317 ///   a) The context being destroyed and unregistering all handlers.
318 ///   b) An explicit call to detach().
319 /// The object may remain live from a Python perspective for an arbitrary time
320 /// after detachment, but there is nothing the user can do with it (since there
321 /// is no way to attach an existing handler object).
322 class PyDiagnosticHandler {
323 public:
324   PyDiagnosticHandler(MlirContext context, pybind11::object callback);
325   ~PyDiagnosticHandler();
326 
327   bool isAttached() { return registeredID.hasValue(); }
328   bool getHadError() { return hadError; }
329 
330   /// Detaches the handler. Does nothing if not attached.
331   void detach();
332 
333   pybind11::object contextEnter() { return pybind11::cast(this); }
334   void contextExit(const pybind11::object &excType,
335                    const pybind11::object &excVal,
336                    const pybind11::object &excTb) {
337     detach();
338   }
339 
340 private:
341   MlirContext context;
342   pybind11::object callback;
343   llvm::Optional<MlirDiagnosticHandlerID> registeredID;
344   bool hadError = false;
345   friend class PyMlirContext;
346 };
347 
348 /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in
349 /// order to differentiate it from the `Dialect` base class which is extended by
350 /// plugins which extend dialect functionality through extension python code.
351 /// This should be seen as the "low-level" object and `Dialect` as the
352 /// high-level, user facing object.
353 class PyDialectDescriptor : public BaseContextObject {
354 public:
355   PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect)
356       : BaseContextObject(std::move(contextRef)), dialect(dialect) {}
357 
358   MlirDialect get() { return dialect; }
359 
360 private:
361   MlirDialect dialect;
362 };
363 
364 /// User-level object for accessing dialects with dotted syntax such as:
365 ///   ctx.dialect.std
366 class PyDialects : public BaseContextObject {
367 public:
368   PyDialects(PyMlirContextRef contextRef)
369       : BaseContextObject(std::move(contextRef)) {}
370 
371   MlirDialect getDialectForKey(const std::string &key, bool attrError);
372 };
373 
374 /// User-level dialect object. For dialects that have a registered extension,
375 /// this will be the base class of the extension dialect type. For un-extended,
376 /// objects of this type will be returned directly.
377 class PyDialect {
378 public:
379   PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {}
380 
381   pybind11::object getDescriptor() { return descriptor; }
382 
383 private:
384   pybind11::object descriptor;
385 };
386 
387 /// Wrapper around an MlirLocation.
388 class PyLocation : public BaseContextObject {
389 public:
390   PyLocation(PyMlirContextRef contextRef, MlirLocation loc)
391       : BaseContextObject(std::move(contextRef)), loc(loc) {}
392 
393   operator MlirLocation() const { return loc; }
394   MlirLocation get() const { return loc; }
395 
396   /// Enter and exit the context manager.
397   pybind11::object contextEnter();
398   void contextExit(const pybind11::object &excType,
399                    const pybind11::object &excVal,
400                    const pybind11::object &excTb);
401 
402   /// Gets a capsule wrapping the void* within the MlirLocation.
403   pybind11::object getCapsule();
404 
405   /// Creates a PyLocation from the MlirLocation wrapped by a capsule.
406   /// Note that PyLocation instances are uniqued, so the returned object
407   /// may be a pre-existing object. Ownership of the underlying MlirLocation
408   /// is taken by calling this function.
409   static PyLocation createFromCapsule(pybind11::object capsule);
410 
411 private:
412   MlirLocation loc;
413 };
414 
415 /// Used in function arguments when None should resolve to the current context
416 /// manager set instance.
417 class DefaultingPyLocation
418     : public Defaulting<DefaultingPyLocation, PyLocation> {
419 public:
420   using Defaulting::Defaulting;
421   static constexpr const char kTypeDescription[] = "mlir.ir.Location";
422   static PyLocation &resolve();
423 
424   operator MlirLocation() const { return *get(); }
425 };
426 
427 /// Wrapper around MlirModule.
428 /// This is the top-level, user-owned object that contains regions/ops/blocks.
429 class PyModule;
430 using PyModuleRef = PyObjectRef<PyModule>;
431 class PyModule : public BaseContextObject {
432 public:
433   /// Returns a PyModule reference for the given MlirModule. This may return
434   /// a pre-existing or new object.
435   static PyModuleRef forModule(MlirModule module);
436   PyModule(PyModule &) = delete;
437   PyModule(PyMlirContext &&) = delete;
438   ~PyModule();
439 
440   /// Gets the backing MlirModule.
441   MlirModule get() { return module; }
442 
443   /// Gets a strong reference to this module.
444   PyModuleRef getRef() {
445     return PyModuleRef(this,
446                        pybind11::reinterpret_borrow<pybind11::object>(handle));
447   }
448 
449   /// Gets a capsule wrapping the void* within the MlirModule.
450   /// Note that the module does not (yet) provide a corresponding factory for
451   /// constructing from a capsule as that would require uniquing PyModule
452   /// instances, which is not currently done.
453   pybind11::object getCapsule();
454 
455   /// Creates a PyModule from the MlirModule wrapped by a capsule.
456   /// Note that PyModule instances are uniqued, so the returned object
457   /// may be a pre-existing object. Ownership of the underlying MlirModule
458   /// is taken by calling this function.
459   static pybind11::object createFromCapsule(pybind11::object capsule);
460 
461 private:
462   PyModule(PyMlirContextRef contextRef, MlirModule module);
463   MlirModule module;
464   pybind11::handle handle;
465 };
466 
467 /// Base class for PyOperation and PyOpView which exposes the primary, user
468 /// visible methods for manipulating it.
469 class PyOperationBase {
470 public:
471   virtual ~PyOperationBase() = default;
472   /// Implements the bound 'print' method and helps with others.
473   void print(pybind11::object fileObject, bool binary,
474              llvm::Optional<int64_t> largeElementsLimit, bool enableDebugInfo,
475              bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
476              bool assumeVerified);
477   pybind11::object getAsm(bool binary,
478                           llvm::Optional<int64_t> largeElementsLimit,
479                           bool enableDebugInfo, bool prettyDebugInfo,
480                           bool printGenericOpForm, bool useLocalScope,
481                           bool assumeVerified);
482 
483   /// Moves the operation before or after the other operation.
484   void moveAfter(PyOperationBase &other);
485   void moveBefore(PyOperationBase &other);
486 
487   /// Each must provide access to the raw Operation.
488   virtual PyOperation &getOperation() = 0;
489 };
490 
491 /// Wrapper around PyOperation.
492 /// Operations exist in either an attached (dependent) or detached (top-level)
493 /// state. In the detached state (as on creation), an operation is owned by
494 /// the creator and its lifetime extends either until its reference count
495 /// drops to zero or it is attached to a parent, at which point its lifetime
496 /// is bounded by its top-level parent reference.
497 class PyOperation;
498 using PyOperationRef = PyObjectRef<PyOperation>;
499 class PyOperation : public PyOperationBase, public BaseContextObject {
500 public:
501   ~PyOperation() override;
502   PyOperation &getOperation() override { return *this; }
503 
504   /// Returns a PyOperation for the given MlirOperation, optionally associating
505   /// it with a parentKeepAlive.
506   static PyOperationRef
507   forOperation(PyMlirContextRef contextRef, MlirOperation operation,
508                pybind11::object parentKeepAlive = pybind11::object());
509 
510   /// Creates a detached operation. The operation must not be associated with
511   /// any existing live operation.
512   static PyOperationRef
513   createDetached(PyMlirContextRef contextRef, MlirOperation operation,
514                  pybind11::object parentKeepAlive = pybind11::object());
515 
516   /// Detaches the operation from its parent block and updates its state
517   /// accordingly.
518   void detachFromParent() {
519     mlirOperationRemoveFromParent(getOperation());
520     setDetached();
521     parentKeepAlive = pybind11::object();
522   }
523 
524   /// Gets the backing operation.
525   operator MlirOperation() const { return get(); }
526   MlirOperation get() const {
527     checkValid();
528     return operation;
529   }
530 
531   PyOperationRef getRef() {
532     return PyOperationRef(
533         this, pybind11::reinterpret_borrow<pybind11::object>(handle));
534   }
535 
536   bool isAttached() { return attached; }
537   void setAttached(const pybind11::object &parent = pybind11::object()) {
538     assert(!attached && "operation already attached");
539     attached = true;
540   }
541   void setDetached() {
542     assert(attached && "operation already detached");
543     attached = false;
544   }
545   void checkValid() const;
546 
547   /// Gets the owning block or raises an exception if the operation has no
548   /// owning block.
549   PyBlock getBlock();
550 
551   /// Gets the parent operation or raises an exception if the operation has
552   /// no parent.
553   llvm::Optional<PyOperationRef> getParentOperation();
554 
555   /// Gets a capsule wrapping the void* within the MlirOperation.
556   pybind11::object getCapsule();
557 
558   /// Creates a PyOperation from the MlirOperation wrapped by a capsule.
559   /// Ownership of the underlying MlirOperation is taken by calling this
560   /// function.
561   static pybind11::object createFromCapsule(pybind11::object capsule);
562 
563   /// Creates an operation. See corresponding python docstring.
564   static pybind11::object
565   create(const std::string &name, llvm::Optional<std::vector<PyType *>> results,
566          llvm::Optional<std::vector<PyValue *>> operands,
567          llvm::Optional<pybind11::dict> attributes,
568          llvm::Optional<std::vector<PyBlock *>> successors, int regions,
569          DefaultingPyLocation location, const pybind11::object &ip);
570 
571   /// Creates an OpView suitable for this operation.
572   pybind11::object createOpView();
573 
574   /// Erases the underlying MlirOperation, removes its pointer from the
575   /// parent context's live operations map, and sets the valid bit false.
576   void erase();
577 
578   /// Clones this operation.
579   pybind11::object clone(const pybind11::object &ip);
580 
581 private:
582   PyOperation(PyMlirContextRef contextRef, MlirOperation operation);
583   static PyOperationRef createInstance(PyMlirContextRef contextRef,
584                                        MlirOperation operation,
585                                        pybind11::object parentKeepAlive);
586 
587   MlirOperation operation;
588   pybind11::handle handle;
589   // Keeps the parent alive, regardless of whether it is an Operation or
590   // Module.
591   // TODO: As implemented, this facility is only sufficient for modeling the
592   // trivial module parent back-reference. Generalize this to also account for
593   // transitions from detached to attached and address TODOs in the
594   // ir_operation.py regarding testing corresponding lifetime guarantees.
595   pybind11::object parentKeepAlive;
596   bool attached = true;
597   bool valid = true;
598 
599   friend class PyOperationBase;
600   friend class PySymbolTable;
601 };
602 
603 /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for
604 /// providing more instance-specific accessors and serve as the base class for
605 /// custom ODS-style operation classes. Since this class is subclass on the
606 /// python side, it must present an __init__ method that operates in pure
607 /// python types.
608 class PyOpView : public PyOperationBase {
609 public:
610   PyOpView(const pybind11::object &operationObject);
611   PyOperation &getOperation() override { return operation; }
612 
613   static pybind11::object createRawSubclass(const pybind11::object &userClass);
614 
615   pybind11::object getOperationObject() { return operationObject; }
616 
617   static pybind11::object
618   buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList,
619                pybind11::list operandList,
620                llvm::Optional<pybind11::dict> attributes,
621                llvm::Optional<std::vector<PyBlock *>> successors,
622                llvm::Optional<int> regions, DefaultingPyLocation location,
623                const pybind11::object &maybeIp);
624 
625 private:
626   PyOperation &operation;           // For efficient, cast-free access from C++
627   pybind11::object operationObject; // Holds the reference.
628 };
629 
630 /// Wrapper around an MlirRegion.
631 /// Regions are managed completely by their containing operation. Unlike the
632 /// C++ API, the python API does not support detached regions.
633 class PyRegion {
634 public:
635   PyRegion(PyOperationRef parentOperation, MlirRegion region)
636       : parentOperation(std::move(parentOperation)), region(region) {
637     assert(!mlirRegionIsNull(region) && "python region cannot be null");
638   }
639   operator MlirRegion() const { return region; }
640 
641   MlirRegion get() { return region; }
642   PyOperationRef &getParentOperation() { return parentOperation; }
643 
644   void checkValid() { return parentOperation->checkValid(); }
645 
646 private:
647   PyOperationRef parentOperation;
648   MlirRegion region;
649 };
650 
651 /// Wrapper around an MlirBlock.
652 /// Blocks are managed completely by their containing operation. Unlike the
653 /// C++ API, the python API does not support detached blocks.
654 class PyBlock {
655 public:
656   PyBlock(PyOperationRef parentOperation, MlirBlock block)
657       : parentOperation(std::move(parentOperation)), block(block) {
658     assert(!mlirBlockIsNull(block) && "python block cannot be null");
659   }
660 
661   MlirBlock get() { return block; }
662   PyOperationRef &getParentOperation() { return parentOperation; }
663 
664   void checkValid() { return parentOperation->checkValid(); }
665 
666 private:
667   PyOperationRef parentOperation;
668   MlirBlock block;
669 };
670 
671 /// An insertion point maintains a pointer to a Block and a reference operation.
672 /// Calls to insert() will insert a new operation before the
673 /// reference operation. If the reference operation is null, then appends to
674 /// the end of the block.
675 class PyInsertionPoint {
676 public:
677   /// Creates an insertion point positioned after the last operation in the
678   /// block, but still inside the block.
679   PyInsertionPoint(PyBlock &block);
680   /// Creates an insertion point positioned before a reference operation.
681   PyInsertionPoint(PyOperationBase &beforeOperationBase);
682 
683   /// Shortcut to create an insertion point at the beginning of the block.
684   static PyInsertionPoint atBlockBegin(PyBlock &block);
685   /// Shortcut to create an insertion point before the block terminator.
686   static PyInsertionPoint atBlockTerminator(PyBlock &block);
687 
688   /// Inserts an operation.
689   void insert(PyOperationBase &operationBase);
690 
691   /// Enter and exit the context manager.
692   pybind11::object contextEnter();
693   void contextExit(const pybind11::object &excType,
694                    const pybind11::object &excVal,
695                    const pybind11::object &excTb);
696 
697   PyBlock &getBlock() { return block; }
698 
699 private:
700   // Trampoline constructor that avoids null initializing members while
701   // looking up parents.
702   PyInsertionPoint(PyBlock block, llvm::Optional<PyOperationRef> refOperation)
703       : refOperation(std::move(refOperation)), block(std::move(block)) {}
704 
705   llvm::Optional<PyOperationRef> refOperation;
706   PyBlock block;
707 };
708 /// Wrapper around the generic MlirType.
709 /// The lifetime of a type is bound by the PyContext that created it.
710 class PyType : public BaseContextObject {
711 public:
712   PyType(PyMlirContextRef contextRef, MlirType type)
713       : BaseContextObject(std::move(contextRef)), type(type) {}
714   bool operator==(const PyType &other);
715   operator MlirType() const { return type; }
716   MlirType get() const { return type; }
717 
718   /// Gets a capsule wrapping the void* within the MlirType.
719   pybind11::object getCapsule();
720 
721   /// Creates a PyType from the MlirType wrapped by a capsule.
722   /// Note that PyType instances are uniqued, so the returned object
723   /// may be a pre-existing object. Ownership of the underlying MlirType
724   /// is taken by calling this function.
725   static PyType createFromCapsule(pybind11::object capsule);
726 
727 private:
728   MlirType type;
729 };
730 
731 /// CRTP base classes for Python types that subclass Type and should be
732 /// castable from it (i.e. via something like IntegerType(t)).
733 /// By default, type class hierarchies are one level deep (i.e. a
734 /// concrete type class extends PyType); however, intermediate python-visible
735 /// base classes can be modeled by specifying a BaseTy.
736 template <typename DerivedTy, typename BaseTy = PyType>
737 class PyConcreteType : public BaseTy {
738 public:
739   // Derived classes must define statics for:
740   //   IsAFunctionTy isaFunction
741   //   const char *pyClassName
742   using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
743   using IsAFunctionTy = bool (*)(MlirType);
744 
745   PyConcreteType() = default;
746   PyConcreteType(PyMlirContextRef contextRef, MlirType t)
747       : BaseTy(std::move(contextRef), t) {}
748   PyConcreteType(PyType &orig)
749       : PyConcreteType(orig.getContext(), castFrom(orig)) {}
750 
751   static MlirType castFrom(PyType &orig) {
752     if (!DerivedTy::isaFunction(orig)) {
753       auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
754       throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") +
755                                              DerivedTy::pyClassName +
756                                              " (from " + origRepr + ")");
757     }
758     return orig;
759   }
760 
761   static void bind(pybind11::module &m) {
762     auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local());
763     cls.def(pybind11::init<PyType &>(), pybind11::keep_alive<0, 1>(),
764             pybind11::arg("cast_from_type"));
765     cls.def_static(
766         "isinstance",
767         [](PyType &otherType) -> bool {
768           return DerivedTy::isaFunction(otherType);
769         },
770         pybind11::arg("other"));
771     DerivedTy::bindDerived(cls);
772   }
773 
774   /// Implemented by derived classes to add methods to the Python subclass.
775   static void bindDerived(ClassTy &m) {}
776 };
777 
778 /// Wrapper around the generic MlirAttribute.
779 /// The lifetime of a type is bound by the PyContext that created it.
780 class PyAttribute : public BaseContextObject {
781 public:
782   PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
783       : BaseContextObject(std::move(contextRef)), attr(attr) {}
784   bool operator==(const PyAttribute &other);
785   operator MlirAttribute() const { return attr; }
786   MlirAttribute get() const { return attr; }
787 
788   /// Gets a capsule wrapping the void* within the MlirAttribute.
789   pybind11::object getCapsule();
790 
791   /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule.
792   /// Note that PyAttribute instances are uniqued, so the returned object
793   /// may be a pre-existing object. Ownership of the underlying MlirAttribute
794   /// is taken by calling this function.
795   static PyAttribute createFromCapsule(pybind11::object capsule);
796 
797 private:
798   MlirAttribute attr;
799 };
800 
801 /// Represents a Python MlirNamedAttr, carrying an optional owned name.
802 /// TODO: Refactor this and the C-API to be based on an Identifier owned
803 /// by the context so as to avoid ownership issues here.
804 class PyNamedAttribute {
805 public:
806   /// Constructs a PyNamedAttr that retains an owned name. This should be
807   /// used in any code that originates an MlirNamedAttribute from a python
808   /// string.
809   /// The lifetime of the PyNamedAttr must extend to the lifetime of the
810   /// passed attribute.
811   PyNamedAttribute(MlirAttribute attr, std::string ownedName);
812 
813   MlirNamedAttribute namedAttr;
814 
815 private:
816   // Since the MlirNamedAttr contains an internal pointer to the actual
817   // memory of the owned string, it must be heap allocated to remain valid.
818   // Otherwise, strings that fit within the small object optimization threshold
819   // will have their memory address change as the containing object is moved,
820   // resulting in an invalid aliased pointer.
821   std::unique_ptr<std::string> ownedName;
822 };
823 
824 /// CRTP base classes for Python attributes that subclass Attribute and should
825 /// be castable from it (i.e. via something like StringAttr(attr)).
826 /// By default, attribute class hierarchies are one level deep (i.e. a
827 /// concrete attribute class extends PyAttribute); however, intermediate
828 /// python-visible base classes can be modeled by specifying a BaseTy.
829 template <typename DerivedTy, typename BaseTy = PyAttribute>
830 class PyConcreteAttribute : public BaseTy {
831 public:
832   // Derived classes must define statics for:
833   //   IsAFunctionTy isaFunction
834   //   const char *pyClassName
835   using ClassTy = pybind11::class_<DerivedTy, BaseTy>;
836   using IsAFunctionTy = bool (*)(MlirAttribute);
837 
838   PyConcreteAttribute() = default;
839   PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
840       : BaseTy(std::move(contextRef), attr) {}
841   PyConcreteAttribute(PyAttribute &orig)
842       : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
843 
844   static MlirAttribute castFrom(PyAttribute &orig) {
845     if (!DerivedTy::isaFunction(orig)) {
846       auto origRepr = pybind11::repr(pybind11::cast(orig)).cast<std::string>();
847       throw SetPyError(PyExc_ValueError,
848                        llvm::Twine("Cannot cast attribute to ") +
849                            DerivedTy::pyClassName + " (from " + origRepr + ")");
850     }
851     return orig;
852   }
853 
854   static void bind(pybind11::module &m) {
855     auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(),
856                        pybind11::module_local());
857     cls.def(pybind11::init<PyAttribute &>(), pybind11::keep_alive<0, 1>(),
858             pybind11::arg("cast_from_attr"));
859     cls.def_static(
860         "isinstance",
861         [](PyAttribute &otherAttr) -> bool {
862           return DerivedTy::isaFunction(otherAttr);
863         },
864         pybind11::arg("other"));
865     cls.def_property_readonly("type", [](PyAttribute &attr) {
866       return PyType(attr.getContext(), mlirAttributeGetType(attr));
867     });
868     DerivedTy::bindDerived(cls);
869   }
870 
871   /// Implemented by derived classes to add methods to the Python subclass.
872   static void bindDerived(ClassTy &m) {}
873 };
874 
875 /// Wrapper around the generic MlirValue.
876 /// Values are managed completely by the operation that resulted in their
877 /// definition. For op result value, this is the operation that defines the
878 /// value. For block argument values, this is the operation that contains the
879 /// block to which the value is an argument (blocks cannot be detached in Python
880 /// bindings so such operation always exists).
881 class PyValue {
882 public:
883   PyValue(PyOperationRef parentOperation, MlirValue value)
884       : parentOperation(std::move(parentOperation)), value(value) {}
885   operator MlirValue() const { return value; }
886 
887   MlirValue get() { return value; }
888   PyOperationRef &getParentOperation() { return parentOperation; }
889 
890   void checkValid() { return parentOperation->checkValid(); }
891 
892   /// Gets a capsule wrapping the void* within the MlirValue.
893   pybind11::object getCapsule();
894 
895   /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
896   /// the underlying MlirValue is still tied to the owning operation.
897   static PyValue createFromCapsule(pybind11::object capsule);
898 
899 private:
900   PyOperationRef parentOperation;
901   MlirValue value;
902 };
903 
904 /// Wrapper around MlirAffineExpr. Affine expressions are owned by the context.
905 class PyAffineExpr : public BaseContextObject {
906 public:
907   PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr)
908       : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {}
909   bool operator==(const PyAffineExpr &other);
910   operator MlirAffineExpr() const { return affineExpr; }
911   MlirAffineExpr get() const { return affineExpr; }
912 
913   /// Gets a capsule wrapping the void* within the MlirAffineExpr.
914   pybind11::object getCapsule();
915 
916   /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule.
917   /// Note that PyAffineExpr instances are uniqued, so the returned object
918   /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr
919   /// is taken by calling this function.
920   static PyAffineExpr createFromCapsule(pybind11::object capsule);
921 
922   PyAffineExpr add(const PyAffineExpr &other) const;
923   PyAffineExpr mul(const PyAffineExpr &other) const;
924   PyAffineExpr floorDiv(const PyAffineExpr &other) const;
925   PyAffineExpr ceilDiv(const PyAffineExpr &other) const;
926   PyAffineExpr mod(const PyAffineExpr &other) const;
927 
928 private:
929   MlirAffineExpr affineExpr;
930 };
931 
932 class PyAffineMap : public BaseContextObject {
933 public:
934   PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap)
935       : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {}
936   bool operator==(const PyAffineMap &other);
937   operator MlirAffineMap() const { return affineMap; }
938   MlirAffineMap get() const { return affineMap; }
939 
940   /// Gets a capsule wrapping the void* within the MlirAffineMap.
941   pybind11::object getCapsule();
942 
943   /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule.
944   /// Note that PyAffineMap instances are uniqued, so the returned object
945   /// may be a pre-existing object. Ownership of the underlying MlirAffineMap
946   /// is taken by calling this function.
947   static PyAffineMap createFromCapsule(pybind11::object capsule);
948 
949 private:
950   MlirAffineMap affineMap;
951 };
952 
953 class PyIntegerSet : public BaseContextObject {
954 public:
955   PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet)
956       : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {}
957   bool operator==(const PyIntegerSet &other);
958   operator MlirIntegerSet() const { return integerSet; }
959   MlirIntegerSet get() const { return integerSet; }
960 
961   /// Gets a capsule wrapping the void* within the MlirIntegerSet.
962   pybind11::object getCapsule();
963 
964   /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule.
965   /// Note that PyIntegerSet instances may be uniqued, so the returned object
966   /// may be a pre-existing object. Integer sets are owned by the context.
967   static PyIntegerSet createFromCapsule(pybind11::object capsule);
968 
969 private:
970   MlirIntegerSet integerSet;
971 };
972 
973 /// Bindings for MLIR symbol tables.
974 class PySymbolTable {
975 public:
976   /// Constructs a symbol table for the given operation.
977   explicit PySymbolTable(PyOperationBase &operation);
978 
979   /// Destroys the symbol table.
980   ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); }
981 
982   /// Returns the symbol (opview) with the given name, throws if there is no
983   /// such symbol in the table.
984   pybind11::object dunderGetItem(const std::string &name);
985 
986   /// Removes the given operation from the symbol table and erases it.
987   void erase(PyOperationBase &symbol);
988 
989   /// Removes the operation with the given name from the symbol table and erases
990   /// it, throws if there is no such symbol in the table.
991   void dunderDel(const std::string &name);
992 
993   /// Inserts the given operation into the symbol table. The operation must have
994   /// the symbol trait.
995   PyAttribute insert(PyOperationBase &symbol);
996 
997   /// Gets and sets the name of a symbol op.
998   static PyAttribute getSymbolName(PyOperationBase &symbol);
999   static void setSymbolName(PyOperationBase &symbol, const std::string &name);
1000 
1001   /// Gets and sets the visibility of a symbol op.
1002   static PyAttribute getVisibility(PyOperationBase &symbol);
1003   static void setVisibility(PyOperationBase &symbol,
1004                             const std::string &visibility);
1005 
1006   /// Replaces all symbol uses within an operation. See the API
1007   /// mlirSymbolTableReplaceAllSymbolUses for all caveats.
1008   static void replaceAllSymbolUses(const std::string &oldSymbol,
1009                                    const std::string &newSymbol,
1010                                    PyOperationBase &from);
1011 
1012   /// Walks all symbol tables under and including 'from'.
1013   static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible,
1014                                pybind11::object callback);
1015 
1016   /// Casts the bindings class into the C API structure.
1017   operator MlirSymbolTable() { return symbolTable; }
1018 
1019 private:
1020   PyOperationRef operation;
1021   MlirSymbolTable symbolTable;
1022 };
1023 
// Binding entry points, one per IR area (affine, attributes, core, interfaces,
// types). Each takes the pybind11 module `m`, presumably to register that
// area's Python bindings on it. Definitions are not in this header; keep
// these declarations in sync with them.
void populateIRAffine(pybind11::module &m);
void populateIRAttributes(pybind11::module &m);
void populateIRCore(pybind11::module &m);
void populateIRInterfaces(pybind11::module &m);
void populateIRTypes(pybind11::module &m);
1029 
1030 } // namespace python
1031 } // namespace mlir
1032 
namespace pybind11 {
namespace detail {

// pybind11 type_caster specializations for the "defaulting" wrapper types, so
// they can appear directly as parameters of bound functions. All conversion
// logic comes from MlirDefaultingCaster, which is defined outside this chunk
// (presumably in PybindUtils.h). NOTE(review): it is assumed to supply an
// implicit default context/location when the Python argument is omitted or
// None; confirm there.
template <>
struct type_caster<mlir::python::DefaultingPyMlirContext>
    : MlirDefaultingCaster<mlir::python::DefaultingPyMlirContext> {};
template <>
struct type_caster<mlir::python::DefaultingPyLocation>
    : MlirDefaultingCaster<mlir::python::DefaultingPyLocation> {};

} // namespace detail
} // namespace pybind11
1045 
1046 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H
1047