1 //===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 10 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 11 12 #include "mlir-c/Support.h" 13 #include "llvm/ADT/Optional.h" 14 #include "llvm/ADT/Twine.h" 15 16 #include <pybind11/pybind11.h> 17 #include <pybind11/stl.h> 18 19 20 namespace mlir { 21 namespace python { 22 23 // Sets a python error, ready to be thrown to return control back to the 24 // python runtime. 25 // Correct usage: 26 // throw SetPyError(PyExc_ValueError, "Foobar'd"); 27 pybind11::error_already_set SetPyError(PyObject *excClass, 28 const llvm::Twine &message); 29 30 /// CRTP template for special wrapper types that are allowed to be passed in as 31 /// 'None' function arguments and can be resolved by some global mechanic if 32 /// so. Such types will raise an error if this global resolution fails, and 33 /// it is actually illegal for them to ever be unresolved. From a user 34 /// perspective, they behave like a smart ptr to the underlying type (i.e. 35 /// 'get' method and operator-> overloaded). 36 /// 37 /// Derived types must provide a method, which is called when an environmental 38 /// resolution is required. It must raise an exception if resolution fails: 39 /// static ReferrentTy &resolve() 40 /// 41 /// They must also provide a parameter description that will be used in 42 /// error messages about mismatched types: 43 /// static constexpr const char kTypeDescription[] = "<Description>"; 44 45 template <typename DerivedTy, typename T> 46 class Defaulting { 47 public: 48 using ReferrentTy = T; 49 /// Type casters require the type to be default constructible, but using 50 /// such an instance is illegal. 51 Defaulting() = default; 52 Defaulting(ReferrentTy &referrent) : referrent(&referrent) {} 53 54 ReferrentTy *get() { return referrent; } 55 ReferrentTy *operator->() { return referrent; } 56 57 private: 58 ReferrentTy *referrent = nullptr; 59 }; 60 61 } // namespace python 62 } // namespace mlir 63 64 namespace pybind11 { 65 namespace detail { 66 67 template <typename DefaultingTy> 68 struct MlirDefaultingCaster { 69 PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription)); 70 71 bool load(pybind11::handle src, bool) { 72 if (src.is_none()) { 73 // Note that we do want an exception to propagate from here as it will be 74 // the most informative. 75 value = DefaultingTy{DefaultingTy::resolve()}; 76 return true; 77 } 78 79 // Unlike many casters that chain, these casters are expected to always 80 // succeed, so instead of doing an isinstance check followed by a cast, 81 // just cast in one step and handle the exception. Returning false (vs 82 // letting the exception propagate) causes higher level signature parsing 83 // code to produce nice error messages (other than "Cannot cast..."). 84 try { 85 value = DefaultingTy{ 86 pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)}; 87 return true; 88 } catch (std::exception &e) { 89 return false; 90 } 91 } 92 93 static handle cast(DefaultingTy src, return_value_policy policy, 94 handle parent) { 95 return pybind11::cast(src, policy); 96 } 97 }; 98 99 template <typename T> 100 struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {}; 101 } // namespace detail 102 } // namespace pybind11 103 104 //------------------------------------------------------------------------------ 105 // Conversion utilities. 106 //------------------------------------------------------------------------------ 107 108 namespace mlir { 109 110 /// Accumulates into a python string from a method that accepts an 111 /// MlirStringCallback. 112 struct PyPrintAccumulator { 113 pybind11::list parts; 114 115 void *getUserData() { return this; } 116 117 MlirStringCallback getCallback() { 118 return [](const char *part, intptr_t size, void *userData) { 119 PyPrintAccumulator *printAccum = 120 static_cast<PyPrintAccumulator *>(userData); 121 pybind11::str pyPart(part, size); // Decodes as UTF-8 by default. 122 printAccum->parts.append(std::move(pyPart)); 123 }; 124 } 125 126 pybind11::str join() { 127 pybind11::str delim("", 0); 128 return delim.attr("join")(parts); 129 } 130 }; 131 132 /// Accumulates int a python file-like object, either writing text (default) 133 /// or binary. 134 class PyFileAccumulator { 135 public: 136 PyFileAccumulator(pybind11::object fileObject, bool binary) 137 : pyWriteFunction(fileObject.attr("write")), binary(binary) {} 138 139 void *getUserData() { return this; } 140 141 MlirStringCallback getCallback() { 142 return [](const char *part, intptr_t size, void *userData) { 143 pybind11::gil_scoped_acquire(); 144 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData); 145 if (accum->binary) { 146 // Note: Still has to copy and not avoidable with this API. 147 pybind11::bytes pyBytes(part, size); 148 accum->pyWriteFunction(pyBytes); 149 } else { 150 pybind11::str pyStr(part, size); // Decodes as UTF-8 by default. 151 accum->pyWriteFunction(pyStr); 152 } 153 }; 154 } 155 156 private: 157 pybind11::object pyWriteFunction; 158 bool binary; 159 }; 160 161 /// Accumulates into a python string from a method that is expected to make 162 /// one (no more, no less) call to the callback (asserts internally on 163 /// violation). 164 struct PySinglePartStringAccumulator { 165 void *getUserData() { return this; } 166 167 MlirStringCallback getCallback() { 168 return [](const char *part, intptr_t size, void *userData) { 169 PySinglePartStringAccumulator *accum = 170 static_cast<PySinglePartStringAccumulator *>(userData); 171 assert(!accum->invoked && 172 "PySinglePartStringAccumulator called back multiple times"); 173 accum->invoked = true; 174 accum->value = pybind11::str(part, size); 175 }; 176 } 177 178 pybind11::str takeValue() { 179 assert(invoked && "PySinglePartStringAccumulator not called back"); 180 return std::move(value); 181 } 182 183 private: 184 pybind11::str value; 185 bool invoked = false; 186 }; 187 188 /// A CRTP base class for pseudo-containers willing to support Python-type 189 /// slicing access on top of indexed access. Calling ::bind on this class 190 /// will define `__len__` as well as `__getitem__` with integer and slice 191 /// arguments. 192 /// 193 /// This is intended for pseudo-containers that can refer to arbitrary slices of 194 /// underlying storage indexed by a single integer. Indexing those with an 195 /// integer produces an instance of ElementTy. Indexing those with a slice 196 /// produces a new instance of Derived, which can be sliced further. 197 /// 198 /// A derived class must provide the following: 199 /// - a `static const char *pyClassName ` field containing the name of the 200 /// Python class to bind; 201 /// - an instance method `intptr_t getNumElements()` that returns the number 202 /// of elements in the backing container (NOT that of the slice); 203 /// - an instance method `ElementTy getElement(intptr_t)` that returns a 204 /// single element at the given index. 205 /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that 206 /// constructs a new instance of the derived pseudo-container with the 207 /// given slice parameters (to be forwarded to the Sliceable constructor). 208 /// 209 /// A derived class may additionally define: 210 /// - a `static void bindDerived(ClassTy &)` method to bind additional methods 211 /// the python class. 212 template <typename Derived, typename ElementTy> 213 class Sliceable { 214 protected: 215 using ClassTy = pybind11::class_<Derived>; 216 217 public: 218 explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step) 219 : startIndex(startIndex), length(length), step(step) { 220 assert(length >= 0 && "expected non-negative slice length"); 221 } 222 223 /// Returns the length of the slice. 224 intptr_t dunderLen() const { return length; } 225 226 /// Returns the element at the given slice index. Supports negative indices 227 /// by taking elements in inverse order. Throws if the index is out of bounds. 228 ElementTy dunderGetItem(intptr_t index) { 229 // Negative indices mean we count from the end. 230 if (index < 0) 231 index = length + index; 232 if (index < 0 || index >= length) { 233 throw python::SetPyError(PyExc_IndexError, 234 "attempt to access out of bounds"); 235 } 236 237 // Compute the linear index given the current slice properties. 238 int linearIndex = index * step + startIndex; 239 assert(linearIndex >= 0 && 240 linearIndex < static_cast<Derived *>(this)->getNumElements() && 241 "linear index out of bounds, the slice is ill-formed"); 242 return static_cast<Derived *>(this)->getElement(linearIndex); 243 } 244 245 /// Returns a new instance of the pseudo-container restricted to the given 246 /// slice. 247 Derived dunderGetItemSlice(pybind11::slice slice) { 248 ssize_t start, stop, extraStep, sliceLength; 249 if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) { 250 throw python::SetPyError(PyExc_IndexError, 251 "attempt to access out of bounds"); 252 } 253 return static_cast<Derived *>(this)->slice(startIndex + start * step, 254 sliceLength, step * extraStep); 255 } 256 257 /// Binds the indexing and length methods in the Python class. 258 static void bind(pybind11::module &m) { 259 auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName) 260 .def("__len__", &Sliceable::dunderLen) 261 .def("__getitem__", &Sliceable::dunderGetItem) 262 .def("__getitem__", &Sliceable::dunderGetItemSlice); 263 Derived::bindDerived(clazz); 264 } 265 266 /// Hook for derived classes willing to bind more methods. 267 static void bindDerived(ClassTy &) {} 268 269 private: 270 intptr_t startIndex; 271 intptr_t length; 272 intptr_t step; 273 }; 274 275 } // namespace mlir 276 277 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 278