1 //===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
10 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
11 
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/Twine.h"

#include <cassert>
17 
18 namespace mlir {
19 namespace python {
20 
// Sets a Python error of exception class `excClass` with the given `message`
// and returns a pybind11::error_already_set, ready to be thrown to return
// control back to the Python runtime.
//
// The returned object is meant to be thrown by the caller; this function does
// not throw it itself. (Defined out of line; the definition is not visible in
// this header.)
//
// Correct usage:
//   throw SetPyError(PyExc_ValueError, "Foobar'd");
pybind11::error_already_set SetPyError(PyObject *excClass,
                                       const llvm::Twine &message);
27 
/// CRTP template for special wrapper types that are allowed to be passed in as
/// 'None' function arguments and can be resolved by some global mechanism if
/// so. Such types will raise an error if this global resolution fails, and
/// it is actually illegal for them to ever be unresolved. From a user
/// perspective, they behave like a non-owning smart pointer to the underlying
/// type (i.e. `get`, `operator->` and `operator*`).
///
/// Derived types must provide a method, which is called when an environmental
/// resolution is required. It must raise an exception if resolution fails:
///   static ReferrentTy &resolve()
///
/// They must also provide a parameter description that will be used in
/// error messages about mismatched types:
///   static constexpr const char kTypeDescription[] = "<Description>";

template <typename DerivedTy, typename T>
class Defaulting {
public:
  using ReferrentTy = T;
  /// Type casters require the type to be default constructible, but using
  /// such an instance is illegal.
  Defaulting() = default;
  /// Wraps `ref` without taking ownership of it. Not `explicit`: callers may
  /// rely on implicit conversion from `ReferrentTy &`.
  Defaulting(ReferrentTy &ref) : referrent(&ref) {}

  /// Returns the wrapped pointer (null only for a default-constructed
  /// instance). Const-qualified because constness of the wrapper is shallow,
  /// as with std::unique_ptr::get.
  ReferrentTy *get() const { return referrent; }

  ReferrentTy *operator->() const {
    assert(referrent && "dereferencing an unresolved Defaulting value");
    return referrent;
  }

  ReferrentTy &operator*() const {
    assert(referrent && "dereferencing an unresolved Defaulting value");
    return *referrent;
  }

private:
  /// Non-owning; null only when default-constructed.
  ReferrentTy *referrent = nullptr;
};
58 
59 } // namespace python
60 } // namespace mlir
61 
62 namespace pybind11 {
63 namespace detail {
64 
65 template <typename DefaultingTy>
66 struct MlirDefaultingCaster {
67   PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription));
68 
69   bool load(pybind11::handle src, bool) {
70     if (src.is_none()) {
71       // Note that we do want an exception to propagate from here as it will be
72       // the most informative.
73       value = DefaultingTy{DefaultingTy::resolve()};
74       return true;
75     }
76 
77     // Unlike many casters that chain, these casters are expected to always
78     // succeed, so instead of doing an isinstance check followed by a cast,
79     // just cast in one step and handle the exception. Returning false (vs
80     // letting the exception propagate) causes higher level signature parsing
81     // code to produce nice error messages (other than "Cannot cast...").
82     try {
83       value = DefaultingTy{
84           pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)};
85       return true;
86     } catch (std::exception &e) {
87       return false;
88     }
89   }
90 
91   static handle cast(DefaultingTy src, return_value_policy policy,
92                      handle parent) {
93     return pybind11::cast(src, policy);
94   }
95 };
96 
97 template <typename T>
98 struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {};
99 } // namespace detail
100 } // namespace pybind11
101 
102 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H
103