1 //===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 10 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 11 12 #include "mlir-c/Support.h" 13 #include "llvm/ADT/Optional.h" 14 #include "llvm/ADT/Twine.h" 15 16 #include <pybind11/pybind11.h> 17 #include <pybind11/stl.h> 18 19 20 namespace mlir { 21 namespace python { 22 23 // Sets a python error, ready to be thrown to return control back to the 24 // python runtime. 25 // Correct usage: 26 // throw SetPyError(PyExc_ValueError, "Foobar'd"); 27 pybind11::error_already_set SetPyError(PyObject *excClass, 28 const llvm::Twine &message); 29 30 /// CRTP template for special wrapper types that are allowed to be passed in as 31 /// 'None' function arguments and can be resolved by some global mechanic if 32 /// so. Such types will raise an error if this global resolution fails, and 33 /// it is actually illegal for them to ever be unresolved. From a user 34 /// perspective, they behave like a smart ptr to the underlying type (i.e. 35 /// 'get' method and operator-> overloaded). 36 /// 37 /// Derived types must provide a method, which is called when an environmental 38 /// resolution is required. It must raise an exception if resolution fails: 39 /// static ReferrentTy &resolve() 40 /// 41 /// They must also provide a parameter description that will be used in 42 /// error messages about mismatched types: 43 /// static constexpr const char kTypeDescription[] = "<Description>"; 44 45 template <typename DerivedTy, typename T> 46 class Defaulting { 47 public: 48 using ReferrentTy = T; 49 /// Type casters require the type to be default constructible, but using 50 /// such an instance is illegal. 51 Defaulting() = default; 52 Defaulting(ReferrentTy &referrent) : referrent(&referrent) {} 53 54 ReferrentTy *get() { return referrent; } 55 ReferrentTy *operator->() { return referrent; } 56 57 private: 58 ReferrentTy *referrent = nullptr; 59 }; 60 61 } // namespace python 62 } // namespace mlir 63 64 namespace pybind11 { 65 namespace detail { 66 67 template <typename DefaultingTy> 68 struct MlirDefaultingCaster { 69 PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription)); 70 71 bool load(pybind11::handle src, bool) { 72 if (src.is_none()) { 73 // Note that we do want an exception to propagate from here as it will be 74 // the most informative. 75 value = DefaultingTy{DefaultingTy::resolve()}; 76 return true; 77 } 78 79 // Unlike many casters that chain, these casters are expected to always 80 // succeed, so instead of doing an isinstance check followed by a cast, 81 // just cast in one step and handle the exception. Returning false (vs 82 // letting the exception propagate) causes higher level signature parsing 83 // code to produce nice error messages (other than "Cannot cast..."). 84 try { 85 value = DefaultingTy{ 86 pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)}; 87 return true; 88 } catch (std::exception &e) { 89 return false; 90 } 91 } 92 93 static handle cast(DefaultingTy src, return_value_policy policy, 94 handle parent) { 95 return pybind11::cast(src, policy); 96 } 97 }; 98 99 template <typename T> 100 struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {}; 101 } // namespace detail 102 } // namespace pybind11 103 104 //------------------------------------------------------------------------------ 105 // Conversion utilities. 106 //------------------------------------------------------------------------------ 107 108 namespace mlir { 109 110 /// Accumulates into a python string from a method that accepts an 111 /// MlirStringCallback. 112 struct PyPrintAccumulator { 113 pybind11::list parts; 114 115 void *getUserData() { return this; } 116 117 MlirStringCallback getCallback() { 118 return [](const char *part, intptr_t size, void *userData) { 119 PyPrintAccumulator *printAccum = 120 static_cast<PyPrintAccumulator *>(userData); 121 pybind11::str pyPart(part, size); // Decodes as UTF-8 by default. 122 printAccum->parts.append(std::move(pyPart)); 123 }; 124 } 125 126 pybind11::str join() { 127 pybind11::str delim("", 0); 128 return delim.attr("join")(parts); 129 } 130 }; 131 132 /// Accumulates int a python file-like object, either writing text (default) 133 /// or binary. 134 class PyFileAccumulator { 135 public: 136 PyFileAccumulator(pybind11::object fileObject, bool binary) 137 : pyWriteFunction(fileObject.attr("write")), binary(binary) {} 138 139 void *getUserData() { return this; } 140 141 MlirStringCallback getCallback() { 142 return [](const char *part, intptr_t size, void *userData) { 143 pybind11::gil_scoped_acquire(); 144 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData); 145 if (accum->binary) { 146 // Note: Still has to copy and not avoidable with this API. 147 pybind11::bytes pyBytes(part, size); 148 accum->pyWriteFunction(pyBytes); 149 } else { 150 pybind11::str pyStr(part, size); // Decodes as UTF-8 by default. 151 accum->pyWriteFunction(pyStr); 152 } 153 }; 154 } 155 156 private: 157 pybind11::object pyWriteFunction; 158 bool binary; 159 }; 160 161 /// Accumulates into a python string from a method that is expected to make 162 /// one (no more, no less) call to the callback (asserts internally on 163 /// violation). 164 struct PySinglePartStringAccumulator { 165 void *getUserData() { return this; } 166 167 MlirStringCallback getCallback() { 168 return [](const char *part, intptr_t size, void *userData) { 169 PySinglePartStringAccumulator *accum = 170 static_cast<PySinglePartStringAccumulator *>(userData); 171 assert(!accum->invoked && 172 "PySinglePartStringAccumulator called back multiple times"); 173 accum->invoked = true; 174 accum->value = pybind11::str(part, size); 175 }; 176 } 177 178 pybind11::str takeValue() { 179 assert(invoked && "PySinglePartStringAccumulator not called back"); 180 return std::move(value); 181 } 182 183 private: 184 pybind11::str value; 185 bool invoked = false; 186 }; 187 188 } // namespace mlir 189 190 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 191