1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file is a part of the ORC runtime support library. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H 14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H 15 16 #include "c_api.h" 17 #include "common.h" 18 #include "error.h" 19 #include "executor_address.h" 20 #include "simple_packed_serialization.h" 21 #include <type_traits> 22 23 namespace __orc_rt { 24 25 /// C++ wrapper function result: Same as CWrapperFunctionResult but 26 /// auto-releases memory. 27 class WrapperFunctionResult { 28 public: 29 /// Create a default WrapperFunctionResult. 30 WrapperFunctionResult() { __orc_rt_CWrapperFunctionResultInit(&R); } 31 32 /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This 33 /// instance takes ownership of the result object and will automatically 34 /// call dispose on the result upon destruction. 35 WrapperFunctionResult(__orc_rt_CWrapperFunctionResult R) : R(R) {} 36 37 WrapperFunctionResult(const WrapperFunctionResult &) = delete; 38 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; 39 40 WrapperFunctionResult(WrapperFunctionResult &&Other) { 41 __orc_rt_CWrapperFunctionResultInit(&R); 42 std::swap(R, Other.R); 43 } 44 45 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { 46 __orc_rt_CWrapperFunctionResult Tmp; 47 __orc_rt_CWrapperFunctionResultInit(&Tmp); 48 std::swap(Tmp, Other.R); 49 std::swap(R, Tmp); 50 return *this; 51 } 52 53 ~WrapperFunctionResult() { __orc_rt_DisposeCWrapperFunctionResult(&R); } 54 55 /// Relinquish ownership of and return the 56 /// __orc_rt_CWrapperFunctionResult. 57 __orc_rt_CWrapperFunctionResult release() { 58 __orc_rt_CWrapperFunctionResult Tmp; 59 __orc_rt_CWrapperFunctionResultInit(&Tmp); 60 std::swap(R, Tmp); 61 return Tmp; 62 } 63 64 /// Get a pointer to the data contained in this instance. 65 char *data() { return __orc_rt_CWrapperFunctionResultData(&R); } 66 67 /// Returns the size of the data contained in this instance. 68 size_t size() const { return __orc_rt_CWrapperFunctionResultSize(&R); } 69 70 /// Returns true if this value is equivalent to a default-constructed 71 /// WrapperFunctionResult. 72 bool empty() const { return __orc_rt_CWrapperFunctionResultEmpty(&R); } 73 74 /// Create a WrapperFunctionResult with the given size and return a pointer 75 /// to the underlying memory. 76 static WrapperFunctionResult allocate(size_t Size) { 77 WrapperFunctionResult R; 78 R.R = __orc_rt_CWrapperFunctionResultAllocate(Size); 79 return R; 80 } 81 82 /// Copy from the given char range. 83 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { 84 return __orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size); 85 } 86 87 /// Copy from the given null-terminated string (includes the null-terminator). 88 static WrapperFunctionResult copyFrom(const char *Source) { 89 return __orc_rt_CreateCWrapperFunctionResultFromString(Source); 90 } 91 92 /// Copy from the given std::string (includes the null terminator). 93 static WrapperFunctionResult copyFrom(const std::string &Source) { 94 return copyFrom(Source.c_str()); 95 } 96 97 /// Create an out-of-band error by copying the given string. 98 static WrapperFunctionResult createOutOfBandError(const char *Msg) { 99 return __orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg); 100 } 101 102 /// Create an out-of-band error by copying the given string. 103 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { 104 return createOutOfBandError(Msg.c_str()); 105 } 106 107 /// If this value is an out-of-band error then this returns the error message, 108 /// otherwise returns nullptr. 109 const char *getOutOfBandError() const { 110 return __orc_rt_CWrapperFunctionResultGetOutOfBandError(&R); 111 } 112 113 private: 114 __orc_rt_CWrapperFunctionResult R; 115 }; 116 117 namespace detail { 118 119 template <typename SPSArgListT, typename... ArgTs> 120 WrapperFunctionResult 121 serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) { 122 auto Result = WrapperFunctionResult::allocate(SPSArgListT::size(Args...)); 123 SPSOutputBuffer OB(Result.data(), Result.size()); 124 if (!SPSArgListT::serialize(OB, Args...)) 125 return WrapperFunctionResult::createOutOfBandError( 126 "Error serializing arguments to blob in call"); 127 return Result; 128 } 129 130 template <typename RetT> class WrapperFunctionHandlerCaller { 131 public: 132 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 133 static decltype(auto) call(HandlerT &&H, ArgTupleT &Args, 134 std::index_sequence<I...>) { 135 return std::forward<HandlerT>(H)(std::get<I>(Args)...); 136 } 137 }; 138 139 template <> class WrapperFunctionHandlerCaller<void> { 140 public: 141 template <typename HandlerT, typename ArgTupleT, std::size_t... I> 142 static SPSEmpty call(HandlerT &&H, ArgTupleT &Args, 143 std::index_sequence<I...>) { 144 std::forward<HandlerT>(H)(std::get<I>(Args)...); 145 return SPSEmpty(); 146 } 147 }; 148 149 template <typename WrapperFunctionImplT, 150 template <typename> class ResultSerializer, typename... SPSTagTs> 151 class WrapperFunctionHandlerHelper 152 : public WrapperFunctionHandlerHelper< 153 decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()), 154 ResultSerializer, SPSTagTs...> {}; 155 156 template <typename RetT, typename... ArgTs, 157 template <typename> class ResultSerializer, typename... SPSTagTs> 158 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 159 SPSTagTs...> { 160 public: 161 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; 162 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; 163 164 template <typename HandlerT> 165 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, 166 size_t ArgSize) { 167 ArgTuple Args; 168 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) 169 return WrapperFunctionResult::createOutOfBandError( 170 "Could not deserialize arguments for wrapper function call"); 171 172 auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call( 173 std::forward<HandlerT>(H), Args, ArgIndices{}); 174 175 return ResultSerializer<decltype(HandlerResult)>::serialize( 176 std::move(HandlerResult)); 177 } 178 179 private: 180 template <std::size_t... I> 181 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, 182 std::index_sequence<I...>) { 183 SPSInputBuffer IB(ArgData, ArgSize); 184 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); 185 } 186 }; 187 188 // Map function pointers to function types. 189 template <typename RetT, typename... ArgTs, 190 template <typename> class ResultSerializer, typename... SPSTagTs> 191 class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer, 192 SPSTagTs...> 193 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 194 SPSTagTs...> {}; 195 196 // Map non-const member function types to function types. 197 template <typename ClassT, typename RetT, typename... ArgTs, 198 template <typename> class ResultSerializer, typename... SPSTagTs> 199 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer, 200 SPSTagTs...> 201 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 202 SPSTagTs...> {}; 203 204 // Map const member function types to function types. 205 template <typename ClassT, typename RetT, typename... ArgTs, 206 template <typename> class ResultSerializer, typename... SPSTagTs> 207 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const, 208 ResultSerializer, SPSTagTs...> 209 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 210 SPSTagTs...> {}; 211 212 template <typename SPSRetTagT, typename RetT> class ResultSerializer { 213 public: 214 static WrapperFunctionResult serialize(RetT Result) { 215 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 216 Result); 217 } 218 }; 219 220 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> { 221 public: 222 static WrapperFunctionResult serialize(Error Err) { 223 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 224 toSPSSerializable(std::move(Err))); 225 } 226 }; 227 228 template <typename SPSRetTagT, typename T> 229 class ResultSerializer<SPSRetTagT, Expected<T>> { 230 public: 231 static WrapperFunctionResult serialize(Expected<T> E) { 232 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 233 toSPSSerializable(std::move(E))); 234 } 235 }; 236 237 template <typename SPSRetTagT, typename RetT> class ResultDeserializer { 238 public: 239 static void makeSafe(RetT &Result) {} 240 241 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { 242 SPSInputBuffer IB(ArgData, ArgSize); 243 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result)) 244 return make_error<StringError>( 245 "Error deserializing return value from blob in call"); 246 return Error::success(); 247 } 248 }; 249 250 template <> class ResultDeserializer<SPSError, Error> { 251 public: 252 static void makeSafe(Error &Err) { cantFail(std::move(Err)); } 253 254 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { 255 SPSInputBuffer IB(ArgData, ArgSize); 256 SPSSerializableError BSE; 257 if (!SPSArgList<SPSError>::deserialize(IB, BSE)) 258 return make_error<StringError>( 259 "Error deserializing return value from blob in call"); 260 Err = fromSPSSerializable(std::move(BSE)); 261 return Error::success(); 262 } 263 }; 264 265 template <typename SPSTagT, typename T> 266 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> { 267 public: 268 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); } 269 270 static Error deserialize(Expected<T> &E, const char *ArgData, 271 size_t ArgSize) { 272 SPSInputBuffer IB(ArgData, ArgSize); 273 SPSSerializableExpected<T> BSE; 274 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE)) 275 return make_error<StringError>( 276 "Error deserializing return value from blob in call"); 277 E = fromSPSSerializable(std::move(BSE)); 278 return Error::success(); 279 } 280 }; 281 282 } // end namespace detail 283 284 template <typename SPSSignature> class WrapperFunction; 285 286 template <typename SPSRetTagT, typename... SPSTagTs> 287 class WrapperFunction<SPSRetTagT(SPSTagTs...)> { 288 private: 289 template <typename RetT> 290 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>; 291 292 public: 293 template <typename RetT, typename... ArgTs> 294 static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) { 295 296 // RetT might be an Error or Expected value. Set the checked flag now: 297 // we don't want the user to have to check the unused result if this 298 // operation fails. 299 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result); 300 301 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx)) 302 return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set"); 303 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch)) 304 return make_error<StringError>("__orc_rt_jit_dispatch not set"); 305 306 auto ArgBuffer = 307 detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>( 308 Args...); 309 if (const char *ErrMsg = ArgBuffer.getOutOfBandError()) 310 return make_error<StringError>(ErrMsg); 311 312 WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch( 313 &__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size()); 314 if (auto ErrMsg = ResultBuffer.getOutOfBandError()) 315 return make_error<StringError>(ErrMsg); 316 317 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( 318 Result, ResultBuffer.data(), ResultBuffer.size()); 319 } 320 321 template <typename HandlerT> 322 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, 323 HandlerT &&Handler) { 324 using WFHH = 325 detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>, 326 ResultSerializer, SPSTagTs...>; 327 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize); 328 } 329 330 private: 331 template <typename T> static const T &makeSerializable(const T &Value) { 332 return Value; 333 } 334 335 static detail::SPSSerializableError makeSerializable(Error Err) { 336 return detail::toSPSSerializable(std::move(Err)); 337 } 338 339 template <typename T> 340 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) { 341 return detail::toSPSSerializable(std::move(E)); 342 } 343 }; 344 345 template <typename... SPSTagTs> 346 class WrapperFunction<void(SPSTagTs...)> 347 : private WrapperFunction<SPSEmpty(SPSTagTs...)> { 348 public: 349 template <typename... ArgTs> 350 static Error call(const void *FnTag, const ArgTs &...Args) { 351 SPSEmpty BE; 352 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...); 353 } 354 355 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle; 356 }; 357 358 /// A function object that takes an ExecutorAddr as its first argument, 359 /// casts that address to a ClassT*, then calls the given method on that 360 /// pointer passing in the remaining function arguments. This utility 361 /// removes some of the boilerplate from writing wrappers for method calls. 362 /// 363 /// @code{.cpp} 364 /// class MyClass { 365 /// public: 366 /// void myMethod(uint32_t, bool) { ... } 367 /// }; 368 /// 369 /// // SPS Method signature -- note MyClass object address as first argument. 370 /// using SPSMyMethodWrapperSignature = 371 /// SPSTuple<SPSExecutorAddr, uint32_t, bool>; 372 /// 373 /// WrapperFunctionResult 374 /// myMethodCallWrapper(const char *ArgData, size_t ArgSize) { 375 /// return WrapperFunction<SPSMyMethodWrapperSignature>::handle( 376 /// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod)); 377 /// } 378 /// @endcode 379 /// 380 template <typename RetT, typename ClassT, typename... ArgTs> 381 class MethodWrapperHandler { 382 public: 383 using MethodT = RetT (ClassT::*)(ArgTs...); 384 MethodWrapperHandler(MethodT M) : M(M) {} 385 RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) { 386 return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...); 387 } 388 389 private: 390 MethodT M; 391 }; 392 393 /// Create a MethodWrapperHandler object from the given method pointer. 394 template <typename RetT, typename ClassT, typename... ArgTs> 395 MethodWrapperHandler<RetT, ClassT, ArgTs...> 396 makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) { 397 return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method); 398 } 399 400 } // end namespace __orc_rt 401 402 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H 403