1 //===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 10 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 11 12 #include <pybind11/pybind11.h> 13 #include <pybind11/stl.h> 14 15 #include "llvm/ADT/Optional.h" 16 #include "llvm/ADT/Twine.h" 17 18 namespace mlir { 19 namespace python { 20 21 // Sets a python error, ready to be thrown to return control back to the 22 // python runtime. 23 // Correct usage: 24 // throw SetPyError(PyExc_ValueError, "Foobar'd"); 25 pybind11::error_already_set SetPyError(PyObject *excClass, 26 const llvm::Twine &message); 27 28 /// CRTP template for special wrapper types that are allowed to be passed in as 29 /// 'None' function arguments and can be resolved by some global mechanic if 30 /// so. Such types will raise an error if this global resolution fails, and 31 /// it is actually illegal for them to ever be unresolved. From a user 32 /// perspective, they behave like a smart ptr to the underlying type (i.e. 33 /// 'get' method and operator-> overloaded). 34 /// 35 /// Derived types must provide a method, which is called when an environmental 36 /// resolution is required. It must raise an exception if resolution fails: 37 /// static ReferrentTy &resolve() 38 /// 39 /// They must also provide a parameter description that will be used in 40 /// error messages about mismatched types: 41 /// static constexpr const char kTypeDescription[] = "<Description>"; 42 43 template <typename DerivedTy, typename T> 44 class Defaulting { 45 public: 46 using ReferrentTy = T; 47 /// Type casters require the type to be default constructible, but using 48 /// such an instance is illegal. 49 Defaulting() = default; 50 Defaulting(ReferrentTy &referrent) : referrent(&referrent) {} 51 52 ReferrentTy *get() { return referrent; } 53 ReferrentTy *operator->() { return referrent; } 54 55 private: 56 ReferrentTy *referrent = nullptr; 57 }; 58 59 } // namespace python 60 } // namespace mlir 61 62 namespace pybind11 { 63 namespace detail { 64 65 template <typename DefaultingTy> 66 struct MlirDefaultingCaster { 67 PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription)); 68 69 bool load(pybind11::handle src, bool) { 70 if (src.is_none()) { 71 // Note that we do want an exception to propagate from here as it will be 72 // the most informative. 73 value = DefaultingTy{DefaultingTy::resolve()}; 74 return true; 75 } 76 77 // Unlike many casters that chain, these casters are expected to always 78 // succeed, so instead of doing an isinstance check followed by a cast, 79 // just cast in one step and handle the exception. Returning false (vs 80 // letting the exception propagate) causes higher level signature parsing 81 // code to produce nice error messages (other than "Cannot cast..."). 82 try { 83 value = DefaultingTy{ 84 pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)}; 85 return true; 86 } catch (std::exception &e) { 87 return false; 88 } 89 } 90 91 static handle cast(DefaultingTy src, return_value_policy policy, 92 handle parent) { 93 return pybind11::cast(src, policy); 94 } 95 }; 96 97 template <typename T> 98 struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {}; 99 } // namespace detail 100 } // namespace pybind11 101 102 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 103