1 //===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file is a part of the ORC runtime support library. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H 14 #define ORC_RT_WRAPPER_FUNCTION_UTILS_H 15 16 #include "c_api.h" 17 #include "common.h" 18 #include "error.h" 19 #include "simple_packed_serialization.h" 20 #include <type_traits> 21 22 namespace __orc_rt { 23 24 /// C++ wrapper function result: Same as CWrapperFunctionResult but 25 /// auto-releases memory. 26 class WrapperFunctionResult { 27 public: 28 /// Create a default WrapperFunctionResult. 29 WrapperFunctionResult() { __orc_rt_CWrapperFunctionResultInit(&R); } 30 31 /// Create a WrapperFunctionResult from a CWrapperFunctionResult. This 32 /// instance takes ownership of the result object and will automatically 33 /// call dispose on the result upon destruction. 34 WrapperFunctionResult(__orc_rt_CWrapperFunctionResult R) : R(R) {} 35 36 WrapperFunctionResult(const WrapperFunctionResult &) = delete; 37 WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete; 38 39 WrapperFunctionResult(WrapperFunctionResult &&Other) { 40 __orc_rt_CWrapperFunctionResultInit(&R); 41 std::swap(R, Other.R); 42 } 43 44 WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) { 45 __orc_rt_CWrapperFunctionResult Tmp; 46 __orc_rt_CWrapperFunctionResultInit(&Tmp); 47 std::swap(Tmp, Other.R); 48 std::swap(R, Tmp); 49 return *this; 50 } 51 52 ~WrapperFunctionResult() { __orc_rt_DisposeCWrapperFunctionResult(&R); } 53 54 /// Relinquish ownership of and return the 55 /// __orc_rt_CWrapperFunctionResult. 56 __orc_rt_CWrapperFunctionResult release() { 57 __orc_rt_CWrapperFunctionResult Tmp; 58 __orc_rt_CWrapperFunctionResultInit(&Tmp); 59 std::swap(R, Tmp); 60 return Tmp; 61 } 62 63 /// Get a pointer to the data contained in this instance. 64 const char *data() const { return __orc_rt_CWrapperFunctionResultData(&R); } 65 66 /// Returns the size of the data contained in this instance. 67 size_t size() const { return __orc_rt_CWrapperFunctionResultSize(&R); } 68 69 /// Returns true if this value is equivalent to a default-constructed 70 /// WrapperFunctionResult. 71 bool empty() const { return __orc_rt_CWrapperFunctionResultEmpty(&R); } 72 73 /// Create a WrapperFunctionResult with the given size and return a pointer 74 /// to the underlying memory. 75 static char *allocate(WrapperFunctionResult &R, size_t Size) { 76 __orc_rt_DisposeCWrapperFunctionResult(&R.R); 77 __orc_rt_CWrapperFunctionResultInit(&R.R); 78 return __orc_rt_CWrapperFunctionResultAllocate(&R.R, Size); 79 } 80 81 /// Copy from the given char range. 82 static WrapperFunctionResult copyFrom(const char *Source, size_t Size) { 83 return __orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size); 84 } 85 86 /// Copy from the given null-terminated string (includes the null-terminator). 87 static WrapperFunctionResult copyFrom(const char *Source) { 88 return __orc_rt_CreateCWrapperFunctionResultFromString(Source); 89 } 90 91 /// Copy from the given std::string (includes the null terminator). 92 static WrapperFunctionResult copyFrom(const std::string &Source) { 93 return copyFrom(Source.c_str()); 94 } 95 96 /// Create an out-of-band error by copying the given string. 97 static WrapperFunctionResult createOutOfBandError(const char *Msg) { 98 return __orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg); 99 } 100 101 /// Create an out-of-band error by copying the given string. 102 static WrapperFunctionResult createOutOfBandError(const std::string &Msg) { 103 return createOutOfBandError(Msg.c_str()); 104 } 105 106 /// If this value is an out-of-band error then this returns the error message, 107 /// otherwise returns nullptr. 108 const char *getOutOfBandError() const { 109 return __orc_rt_CWrapperFunctionResultGetOutOfBandError(&R); 110 } 111 112 private: 113 __orc_rt_CWrapperFunctionResult R; 114 }; 115 116 namespace detail { 117 118 template <typename SPSArgListT, typename... ArgTs> 119 Expected<WrapperFunctionResult> 120 serializeViaSPSToWrapperFunctionResult(const ArgTs &...Args) { 121 WrapperFunctionResult Result; 122 char *DataPtr = 123 WrapperFunctionResult::allocate(Result, SPSArgListT::size(Args...)); 124 SPSOutputBuffer OB(DataPtr, Result.size()); 125 if (!SPSArgListT::serialize(OB, Args...)) 126 return make_error<StringError>( 127 "Error serializing arguments to blob in call"); 128 return Result; 129 } 130 131 template <typename WrapperFunctionImplT, 132 template <typename> class ResultSerializer, typename... SPSTagTs> 133 class WrapperFunctionHandlerHelper 134 : public WrapperFunctionHandlerHelper< 135 decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()), 136 ResultSerializer, SPSTagTs...> {}; 137 138 template <typename RetT, typename... ArgTs, 139 template <typename> class ResultSerializer, typename... SPSTagTs> 140 class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 141 SPSTagTs...> { 142 public: 143 using ArgTuple = std::tuple<std::decay_t<ArgTs>...>; 144 using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>; 145 146 template <typename HandlerT> 147 static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData, 148 size_t ArgSize) { 149 ArgTuple Args; 150 if (!deserialize(ArgData, ArgSize, Args, ArgIndices{})) 151 return WrapperFunctionResult::createOutOfBandError( 152 "Could not deserialize arguments for wrapper function call"); 153 154 if (auto Result = ResultSerializer<RetT>::serialize( 155 call(std::forward<HandlerT>(H), Args, ArgIndices{}))) 156 return std::move(*Result); 157 else 158 return WrapperFunctionResult::createOutOfBandError( 159 toString(Result.takeError())); 160 } 161 162 private: 163 template <std::size_t... I> 164 static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args, 165 std::index_sequence<I...>) { 166 SPSInputBuffer IB(ArgData, ArgSize); 167 return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...); 168 } 169 170 template <typename HandlerT, std::size_t... I> 171 static decltype(auto) call(HandlerT &&H, ArgTuple &Args, 172 std::index_sequence<I...>) { 173 return std::forward<HandlerT>(H)(std::get<I>(Args)...); 174 } 175 }; 176 177 // Map function references to function types. 178 template <typename RetT, typename... ArgTs, 179 template <typename> class ResultSerializer, typename... SPSTagTs> 180 class WrapperFunctionHandlerHelper<RetT (&)(ArgTs...), ResultSerializer, 181 SPSTagTs...> 182 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 183 SPSTagTs...> {}; 184 185 // Map non-const member function types to function types. 186 template <typename ClassT, typename RetT, typename... ArgTs, 187 template <typename> class ResultSerializer, typename... SPSTagTs> 188 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer, 189 SPSTagTs...> 190 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 191 SPSTagTs...> {}; 192 193 // Map const member function types to function types. 194 template <typename ClassT, typename RetT, typename... ArgTs, 195 template <typename> class ResultSerializer, typename... SPSTagTs> 196 class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const, 197 ResultSerializer, SPSTagTs...> 198 : public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer, 199 SPSTagTs...> {}; 200 201 template <typename SPSRetTagT, typename RetT> class ResultSerializer { 202 public: 203 static Expected<WrapperFunctionResult> serialize(RetT Result) { 204 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 205 Result); 206 } 207 }; 208 209 template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> { 210 public: 211 static Expected<WrapperFunctionResult> serialize(Error Err) { 212 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 213 toSPSSerializable(std::move(Err))); 214 } 215 }; 216 217 template <typename SPSRetTagT, typename T> 218 class ResultSerializer<SPSRetTagT, Expected<T>> { 219 public: 220 static Expected<WrapperFunctionResult> serialize(Expected<T> E) { 221 return serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSRetTagT>>( 222 toSPSSerializable(std::move(E))); 223 } 224 }; 225 226 template <typename SPSRetTagT, typename RetT> class ResultDeserializer { 227 public: 228 static void makeSafe(RetT &Result) {} 229 230 static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) { 231 SPSInputBuffer IB(ArgData, ArgSize); 232 if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result)) 233 return make_error<StringError>( 234 "Error deserializing return value from blob in call"); 235 return Error::success(); 236 } 237 }; 238 239 template <> class ResultDeserializer<SPSError, Error> { 240 public: 241 static void makeSafe(Error &Err) { cantFail(std::move(Err)); } 242 243 static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) { 244 SPSInputBuffer IB(ArgData, ArgSize); 245 SPSSerializableError BSE; 246 if (!SPSArgList<SPSError>::deserialize(IB, BSE)) 247 return make_error<StringError>( 248 "Error deserializing return value from blob in call"); 249 Err = fromSPSSerializable(std::move(BSE)); 250 return Error::success(); 251 } 252 }; 253 254 template <typename SPSTagT, typename T> 255 class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> { 256 public: 257 static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); } 258 259 static Error deserialize(Expected<T> &E, const char *ArgData, 260 size_t ArgSize) { 261 SPSInputBuffer IB(ArgData, ArgSize); 262 SPSSerializableExpected<T> BSE; 263 if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE)) 264 return make_error<StringError>( 265 "Error deserializing return value from blob in call"); 266 E = fromSPSSerializable(std::move(BSE)); 267 return Error::success(); 268 } 269 }; 270 271 } // end namespace detail 272 273 template <typename SPSSignature> class WrapperFunction; 274 275 template <typename SPSRetTagT, typename... SPSTagTs> 276 class WrapperFunction<SPSRetTagT(SPSTagTs...)> { 277 private: 278 template <typename RetT> 279 using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>; 280 281 public: 282 template <typename RetT, typename... ArgTs> 283 static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) { 284 285 // RetT might be an Error or Expected value. Set the checked flag now: 286 // we don't want the user to have to check the unused result if this 287 // operation fails. 288 detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result); 289 290 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx)) 291 return make_error<StringError>("__orc_jtjit_dispatch_ctx not set"); 292 if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx)) 293 return make_error<StringError>("__orc_jtjit_dispatch not set"); 294 295 auto ArgBuffer = 296 detail::serializeViaSPSToWrapperFunctionResult<SPSArgList<SPSTagTs...>>( 297 Args...); 298 if (!ArgBuffer) 299 return ArgBuffer.takeError(); 300 301 WrapperFunctionResult ResultBuffer = 302 __orc_rt_jit_dispatch(&__orc_rt_jit_dispatch_ctx, FnTag, 303 ArgBuffer->data(), ArgBuffer->size()); 304 if (auto ErrMsg = ResultBuffer.getOutOfBandError()) 305 return make_error<StringError>(ErrMsg); 306 307 return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize( 308 Result, ResultBuffer.data(), ResultBuffer.size()); 309 } 310 311 template <typename HandlerT> 312 static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize, 313 HandlerT &&Handler) { 314 using WFHH = 315 detail::WrapperFunctionHandlerHelper<HandlerT, ResultSerializer, 316 SPSTagTs...>; 317 return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize); 318 } 319 320 private: 321 template <typename T> static const T &makeSerializable(const T &Value) { 322 return Value; 323 } 324 325 static detail::SPSSerializableError makeSerializable(Error Err) { 326 return detail::toSPSSerializable(std::move(Err)); 327 } 328 329 template <typename T> 330 static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) { 331 return detail::toSPSSerializable(std::move(E)); 332 } 333 }; 334 335 template <typename... SPSTagTs> 336 class WrapperFunction<void(SPSTagTs...)> 337 : private WrapperFunction<SPSEmpty(SPSTagTs...)> { 338 public: 339 template <typename... ArgTs> 340 static Error call(const void *FnTag, const ArgTs &...Args) { 341 SPSEmpty BE; 342 return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...); 343 } 344 345 using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle; 346 }; 347 348 } // end namespace __orc_rt 349 350 #endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H 351