1 //===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 10 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 11 12 #include "mlir-c/Support.h" 13 #include "llvm/ADT/Optional.h" 14 #include "llvm/ADT/Twine.h" 15 16 #include <pybind11/pybind11.h> 17 #include <pybind11/stl.h> 18 19 namespace mlir { 20 namespace python { 21 22 // Sets a python error, ready to be thrown to return control back to the 23 // python runtime. 24 // Correct usage: 25 // throw SetPyError(PyExc_ValueError, "Foobar'd"); 26 pybind11::error_already_set SetPyError(PyObject *excClass, 27 const llvm::Twine &message); 28 29 /// CRTP template for special wrapper types that are allowed to be passed in as 30 /// 'None' function arguments and can be resolved by some global mechanic if 31 /// so. Such types will raise an error if this global resolution fails, and 32 /// it is actually illegal for them to ever be unresolved. From a user 33 /// perspective, they behave like a smart ptr to the underlying type (i.e. 34 /// 'get' method and operator-> overloaded). 35 /// 36 /// Derived types must provide a method, which is called when an environmental 37 /// resolution is required. It must raise an exception if resolution fails: 38 /// static ReferrentTy &resolve() 39 /// 40 /// They must also provide a parameter description that will be used in 41 /// error messages about mismatched types: 42 /// static constexpr const char kTypeDescription[] = "<Description>"; 43 44 template <typename DerivedTy, typename T> 45 class Defaulting { 46 public: 47 using ReferrentTy = T; 48 /// Type casters require the type to be default constructible, but using 49 /// such an instance is illegal. 50 Defaulting() = default; 51 Defaulting(ReferrentTy &referrent) : referrent(&referrent) {} 52 53 ReferrentTy *get() const { return referrent; } 54 ReferrentTy *operator->() { return referrent; } 55 56 private: 57 ReferrentTy *referrent = nullptr; 58 }; 59 60 } // namespace python 61 } // namespace mlir 62 63 namespace pybind11 { 64 namespace detail { 65 66 template <typename DefaultingTy> 67 struct MlirDefaultingCaster { 68 PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription)); 69 70 bool load(pybind11::handle src, bool) { 71 if (src.is_none()) { 72 // Note that we do want an exception to propagate from here as it will be 73 // the most informative. 74 value = DefaultingTy{DefaultingTy::resolve()}; 75 return true; 76 } 77 78 // Unlike many casters that chain, these casters are expected to always 79 // succeed, so instead of doing an isinstance check followed by a cast, 80 // just cast in one step and handle the exception. Returning false (vs 81 // letting the exception propagate) causes higher level signature parsing 82 // code to produce nice error messages (other than "Cannot cast..."). 83 try { 84 value = DefaultingTy{ 85 pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)}; 86 return true; 87 } catch (std::exception &) { 88 return false; 89 } 90 } 91 92 static handle cast(DefaultingTy src, return_value_policy policy, 93 handle parent) { 94 return pybind11::cast(src, policy); 95 } 96 }; 97 98 template <typename T> 99 struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {}; 100 } // namespace detail 101 } // namespace pybind11 102 103 //------------------------------------------------------------------------------ 104 // Conversion utilities. 105 //------------------------------------------------------------------------------ 106 107 namespace mlir { 108 109 /// Accumulates into a python string from a method that accepts an 110 /// MlirStringCallback. 111 struct PyPrintAccumulator { 112 pybind11::list parts; 113 114 void *getUserData() { return this; } 115 116 MlirStringCallback getCallback() { 117 return [](MlirStringRef part, void *userData) { 118 PyPrintAccumulator *printAccum = 119 static_cast<PyPrintAccumulator *>(userData); 120 pybind11::str pyPart(part.data, 121 part.length); // Decodes as UTF-8 by default. 122 printAccum->parts.append(std::move(pyPart)); 123 }; 124 } 125 126 pybind11::str join() { 127 pybind11::str delim("", 0); 128 return delim.attr("join")(parts); 129 } 130 }; 131 132 /// Accumulates int a python file-like object, either writing text (default) 133 /// or binary. 134 class PyFileAccumulator { 135 public: 136 PyFileAccumulator(const pybind11::object &fileObject, bool binary) 137 : pyWriteFunction(fileObject.attr("write")), binary(binary) {} 138 139 void *getUserData() { return this; } 140 141 MlirStringCallback getCallback() { 142 return [](MlirStringRef part, void *userData) { 143 pybind11::gil_scoped_acquire acquire; 144 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData); 145 if (accum->binary) { 146 // Note: Still has to copy and not avoidable with this API. 147 pybind11::bytes pyBytes(part.data, part.length); 148 accum->pyWriteFunction(pyBytes); 149 } else { 150 pybind11::str pyStr(part.data, 151 part.length); // Decodes as UTF-8 by default. 152 accum->pyWriteFunction(pyStr); 153 } 154 }; 155 } 156 157 private: 158 pybind11::object pyWriteFunction; 159 bool binary; 160 }; 161 162 /// Accumulates into a python string from a method that is expected to make 163 /// one (no more, no less) call to the callback (asserts internally on 164 /// violation). 165 struct PySinglePartStringAccumulator { 166 void *getUserData() { return this; } 167 168 MlirStringCallback getCallback() { 169 return [](MlirStringRef part, void *userData) { 170 PySinglePartStringAccumulator *accum = 171 static_cast<PySinglePartStringAccumulator *>(userData); 172 assert(!accum->invoked && 173 "PySinglePartStringAccumulator called back multiple times"); 174 accum->invoked = true; 175 accum->value = pybind11::str(part.data, part.length); 176 }; 177 } 178 179 pybind11::str takeValue() { 180 assert(invoked && "PySinglePartStringAccumulator not called back"); 181 return std::move(value); 182 } 183 184 private: 185 pybind11::str value; 186 bool invoked = false; 187 }; 188 189 /// A CRTP base class for pseudo-containers willing to support Python-type 190 /// slicing access on top of indexed access. Calling ::bind on this class 191 /// will define `__len__` as well as `__getitem__` with integer and slice 192 /// arguments. 193 /// 194 /// This is intended for pseudo-containers that can refer to arbitrary slices of 195 /// underlying storage indexed by a single integer. Indexing those with an 196 /// integer produces an instance of ElementTy. Indexing those with a slice 197 /// produces a new instance of Derived, which can be sliced further. 198 /// 199 /// A derived class must provide the following: 200 /// - a `static const char *pyClassName ` field containing the name of the 201 /// Python class to bind; 202 /// - an instance method `intptr_t getNumElements()` that returns the number 203 /// of elements in the backing container (NOT that of the slice); 204 /// - an instance method `ElementTy getElement(intptr_t)` that returns a 205 /// single element at the given index. 206 /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that 207 /// constructs a new instance of the derived pseudo-container with the 208 /// given slice parameters (to be forwarded to the Sliceable constructor). 209 /// 210 /// The getNumElements() and getElement(intptr_t) callbacks must not throw. 211 /// 212 /// A derived class may additionally define: 213 /// - a `static void bindDerived(ClassTy &)` method to bind additional methods 214 /// the python class. 215 template <typename Derived, typename ElementTy> 216 class Sliceable { 217 protected: 218 using ClassTy = pybind11::class_<Derived>; 219 220 // Transforms `index` into a legal value to access the underlying sequence. 221 // Returns <0 on failure. 222 intptr_t wrapIndex(intptr_t index) { 223 if (index < 0) 224 index = length + index; 225 if (index < 0 || index >= length) 226 return -1; 227 return index; 228 } 229 230 /// Returns the element at the given slice index. Supports negative indices 231 /// by taking elements in inverse order. Returns a nullptr object if out 232 /// of bounds. 233 pybind11::object getItem(intptr_t index) { 234 // Negative indices mean we count from the end. 235 index = wrapIndex(index); 236 if (index < 0) { 237 PyErr_SetString(PyExc_IndexError, "index out of range"); 238 return {}; 239 } 240 241 // Compute the linear index given the current slice properties. 242 int linearIndex = index * step + startIndex; 243 assert(linearIndex >= 0 && 244 linearIndex < static_cast<Derived *>(this)->getNumElements() && 245 "linear index out of bounds, the slice is ill-formed"); 246 return pybind11::cast( 247 static_cast<Derived *>(this)->getElement(linearIndex)); 248 } 249 250 /// Returns a new instance of the pseudo-container restricted to the given 251 /// slice. Returns a nullptr object on failure. 252 pybind11::object getItemSlice(PyObject *slice) { 253 ssize_t start, stop, extraStep, sliceLength; 254 if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep, 255 &sliceLength) != 0) { 256 PyErr_SetString(PyExc_IndexError, "index out of range"); 257 return {}; 258 } 259 return pybind11::cast(static_cast<Derived *>(this)->slice( 260 startIndex + start * step, sliceLength, step * extraStep)); 261 } 262 263 public: 264 explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step) 265 : startIndex(startIndex), length(length), step(step) { 266 assert(length >= 0 && "expected non-negative slice length"); 267 } 268 269 /// Returns a new vector (mapped to Python list) containing elements from two 270 /// slices. The new vector is necessary because slices may not be contiguous 271 /// or even come from the same original sequence. 272 std::vector<ElementTy> dunderAdd(Derived &other) { 273 std::vector<ElementTy> elements; 274 elements.reserve(length + other.length); 275 for (intptr_t i = 0; i < length; ++i) { 276 elements.push_back(static_cast<Derived *>(this)->getElement(i)); 277 } 278 for (intptr_t i = 0; i < other.length; ++i) { 279 elements.push_back(static_cast<Derived *>(this)->getElement(i)); 280 } 281 return elements; 282 } 283 284 /// Binds the indexing and length methods in the Python class. 285 static void bind(pybind11::module &m) { 286 auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName, 287 pybind11::module_local()) 288 .def("__add__", &Sliceable::dunderAdd); 289 Derived::bindDerived(clazz); 290 291 // Manually implement the sequence protocol via the C API. We do this 292 // because it is approx 4x faster than via pybind11, largely because that 293 // formulation requires a C++ exception to be thrown to detect end of 294 // sequence. 295 // Since we are in a C-context, any C++ exception that happens here 296 // will terminate the program. There is nothing in this implementation 297 // that should throw in a non-terminal way, so we forgo further 298 // exception marshalling. 299 // See: https://github.com/pybind/pybind11/issues/2842 300 auto heap_type = reinterpret_cast<PyHeapTypeObject *>(clazz.ptr()); 301 assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && 302 "must be heap type"); 303 heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t { 304 auto self = pybind11::cast<Derived *>(rawSelf); 305 return self->length; 306 }; 307 // sq_item is called as part of the sequence protocol for iteration, 308 // list construction, etc. 309 heap_type->as_sequence.sq_item = 310 +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * { 311 auto self = pybind11::cast<Derived *>(rawSelf); 312 return self->getItem(index).release().ptr(); 313 }; 314 // mp_subscript is used for both slices and integer lookups. 315 heap_type->as_mapping.mp_subscript = 316 +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * { 317 auto self = pybind11::cast<Derived *>(rawSelf); 318 Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError); 319 if (!PyErr_Occurred()) { 320 // Integer indexing. 321 return self->getItem(index).release().ptr(); 322 } 323 PyErr_Clear(); 324 325 // Assume slice-based indexing. 326 if (PySlice_Check(rawSubscript)) { 327 return self->getItemSlice(rawSubscript).release().ptr(); 328 } 329 330 PyErr_SetString(PyExc_ValueError, "expected integer or slice"); 331 return nullptr; 332 }; 333 } 334 335 /// Hook for derived classes willing to bind more methods. 336 static void bindDerived(ClassTy &) {} 337 338 private: 339 intptr_t startIndex; 340 intptr_t length; 341 intptr_t step; 342 }; 343 344 } // namespace mlir 345 346 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 347